qninhdt committed on
Commit 6ef7ab3
1 Parent(s): 798fdd3
.gitignore CHANGED
@@ -1,3 +1,4 @@
 __pycache__
 datasets
+wandb
 *.zip
configs/autoencoder/autoencoder_kl_32x32x4.yaml ADDED
@@ -0,0 +1,24 @@
+model:
+  base_learning_rate: 4.5e-6
+  target: swim.models.autoencoder.AutoencoderKL
+  params:
+    monitor: "val/rec_loss"
+    embed_dim: 4
+    lossconfig:
+      target: swim.modules.losses.LPIPSWithDiscriminator
+      params:
+        disc_start: 50001
+        kl_weight: 0.000001
+        disc_weight: 0.5
+
+    ddconfig:
+      double_z: True
+      z_channels: 4
+      resolution: 512
+      in_channels: 3
+      out_ch: 3
+      ch: 128
+      ch_mult: [1, 2, 4, 4]  # num_down = len(ch_mult)-1
+      num_res_blocks: 2
+      attn_resolutions: []
+      dropout: 0.0
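The spatial compression this ddconfig implies follows directly from ch_mult: len(ch_mult) - 1 = 3 downsampling stages, i.e. a factor of 8. A small sketch of the arithmetic (the OmegaConf calls are standard; the derived shapes are computed here, not taken from the repo):

```python
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/autoencoder/autoencoder_kl_32x32x4.yaml")
dd = cfg.model.params.ddconfig

factor = 2 ** (len(dd.ch_mult) - 1)      # 2**3 = 8
latent_hw = dd.resolution // factor      # 512 // 8 = 64
latent_ch = cfg.model.params.embed_dim   # 4
print(f"latent shape: {latent_ch} x {latent_hw} x {latent_hw}")
# main.py below trains with img_size=32, which gives 4 x 4 x 4 latents instead.
```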
lightning_logs/version_0/hparams.yaml ADDED
@@ -0,0 +1 @@
+{}
lightning_logs/version_1/hparams.yaml ADDED
@@ -0,0 +1 @@
+{}
lightning_logs/version_2/hparams.yaml ADDED
@@ -0,0 +1 @@
+{}
lightning_logs/version_3/hparams.yaml ADDED
@@ -0,0 +1 @@
+{}
lightning_logs/version_4/hparams.yaml ADDED
@@ -0,0 +1 @@
+{}
lightning_logs/version_5/hparams.yaml ADDED
@@ -0,0 +1 @@
+{}
lightning_logs/version_6/hparams.yaml ADDED
@@ -0,0 +1 @@
+{}
lightning_logs/version_6/metrics.csv ADDED
@@ -0,0 +1,8 @@
1
+ aeloss_step,discloss_step,epoch,step,train/d_weight,train/disc_factor,train/disc_loss,train/g_loss,train/kl_loss,train/logits_fake,train/logits_real,train/logvar,train/nll_loss,train/rec_loss,train/total_loss
2
+ 2953.825927734375,0.0,0,49,247.98703002929688,0.0,0.0,0.4574800729751587,3.8113327026367188,-0.4574800729751587,0.25408393144607544,0.0,2953.825927734375,0.9615318775177002,2953.825927734375
3
+ 2997.572021484375,0.0,0,99,582.9641723632812,0.0,0.0,0.30489930510520935,9.548039436340332,-0.30489930510520935,-0.17654460668563843,0.0,2997.572021484375,0.9757721424102783,2997.572021484375
4
+ 3229.461181640625,0.0,0,149,153.92608642578125,0.0,0.0,0.2801606059074402,38.64592742919922,-0.2801606059074402,0.49510449171066284,0.0,3229.461181640625,1.0512568950653076,3229.461181640625
5
+ 2391.34716796875,0.0,0,199,171.67630004882812,0.0,0.0,0.3865272104740143,29.355663299560547,-0.3865272104740143,-0.2872551679611206,0.0,2391.34716796875,0.7784333229064941,2391.34716796875
6
+ 2008.2764892578125,0.0,0,249,223.92453002929688,0.0,0.0,0.1616521179676056,75.05839538574219,-0.1616521179676056,-0.43370532989501953,0.0,2008.2763671875,0.6537358164787292,2008.2764892578125
7
+ 2754.87353515625,0.0,0,299,370.9005432128906,0.0,0.0,0.08434182405471802,16.720157623291016,-0.08434182405471802,-0.3120231032371521,0.0,2754.87353515625,0.8967687487602234,2754.87353515625
8
+ 1949.1402587890625,0.0,0,349,579.2057495117188,0.0,0.0,0.4615066647529602,88.49298095703125,-0.4615066647529602,-0.39516788721084595,0.0,1949.14013671875,0.6344857215881348,1949.1402587890625
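These columns appear to come from Lightning's CSV logger (one row per logged training step). A stdlib-only sketch for pulling out the reconstruction-loss curve:

```python
import csv

with open("lightning_logs/version_6/metrics.csv") as f:
    rows = list(csv.DictReader(f))

for row in rows:
    # step, per-step reconstruction loss, and the combined autoencoder loss
    print(row["step"], row["train/rec_loss"], row["aeloss_step"])
```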
main.py ADDED
@@ -0,0 +1,19 @@
+import torch
+from omegaconf import OmegaConf
+from swim.utils import instantiate_from_config
+from torchinfo import summary
+from swim.modules.dataset import SwimDataModule
+from lightning import Trainer
+from lightning.pytorch.loggers import WandbLogger
+
+config = OmegaConf.load("configs/autoencoder/autoencoder_kl_32x32x4.yaml")
+
+model = instantiate_from_config(config.model)
+model.learning_rate = config.model.base_learning_rate
+
+datamodule = SwimDataModule(img_size=32)
+
+logger = WandbLogger(project="swim", name="autoencoder_kl")
+
+trainer = Trainer(max_epochs=10, devices=[0], logger=logger, log_every_n_steps=10)
+trainer.fit(model, datamodule)
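swim/utils.py is not part of this commit, so the exact behaviour of instantiate_from_config is not shown; judging from the target/params layout of the YAML it presumably follows the usual latent-diffusion convention. A hypothetical equivalent, for orientation only:

```python
import importlib

def instantiate_from_config(config):
    # Hypothetical re-implementation: import the dotted `target` path and
    # pass `params` (if any) as keyword arguments.
    module_name, cls_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), cls_name)
    return cls(**config.get("params", dict()))
```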
swim/attention_blocks.py DELETED
@@ -1,315 +0,0 @@
1
- from typing import Optional
2
-
3
- import torch
4
- import torch.nn.functional as F
5
- from torch import nn
6
-
7
-
8
- class SpatialTransformer(nn.Module):
9
- """
10
- ## Spatial Transformer
11
- """
12
-
13
- def __init__(self, channels: int, n_heads: int, n_layers: int, d_cond: int):
14
- """
15
- :param channels: is the number of channels in the feature map
16
- :param n_heads: is the number of attention heads
17
- :param n_layers: is the number of transformer layers
18
- :param d_cond: is the size of the conditional embedding
19
- """
20
- super().__init__()
21
- # Initial group normalization
22
- self.norm = torch.nn.GroupNorm(
23
- num_groups=32, num_channels=channels, eps=1e-6, affine=True
24
- )
25
- # Initial $1 \times 1$ convolution
26
- self.proj_in = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0)
27
-
28
- # Transformer layers
29
- self.transformer_blocks = nn.ModuleList(
30
- [
31
- BasicTransformerBlock(
32
- channels, n_heads, channels // n_heads, d_cond=d_cond
33
- )
34
- for _ in range(n_layers)
35
- ]
36
- )
37
-
38
- # Final $1 \times 1$ convolution
39
- self.proj_out = nn.Conv2d(
40
- channels, channels, kernel_size=1, stride=1, padding=0
41
- )
42
-
43
- def forward(self, x: torch.Tensor, cond: torch.Tensor):
44
- """
45
- :param x: is the feature map of shape `[batch_size, channels, height, width]`
46
- :param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
47
- """
48
- # Get shape `[batch_size, channels, height, width]`
49
- b, c, h, w = x.shape
50
- # For residual connection
51
- x_in = x
52
- # Normalize
53
- x = self.norm(x)
54
- # Initial $1 \times 1$ convolution
55
- x = self.proj_in(x)
56
- # Transpose and reshape from `[batch_size, channels, height, width]`
57
- # to `[batch_size, height * width, channels]`
58
- x = x.permute(0, 2, 3, 1).view(b, h * w, c)
59
- # Apply the transformer layers
60
- for block in self.transformer_blocks:
61
- x = block(x, cond)
62
- # Reshape and transpose from `[batch_size, height * width, channels]`
63
- # to `[batch_size, channels, height, width]`
64
- x = x.view(b, h, w, c).permute(0, 3, 1, 2)
65
- # Final $1 \times 1$ convolution
66
- x = self.proj_out(x)
67
- # Add residual
68
- return x + x_in
69
-
70
-
71
- class BasicTransformerBlock(nn.Module):
72
- """
73
- ### Transformer Layer
74
- """
75
-
76
- def __init__(self, d_model: int, n_heads: int, d_head: int, d_cond: int):
77
- """
78
- :param d_model: is the input embedding size
79
- :param n_heads: is the number of attention heads
80
- :param d_head: is the size of a attention head
81
- :param d_cond: is the size of the conditional embeddings
82
- """
83
- super().__init__()
84
- # Self-attention layer and pre-norm layer
85
- self.attn1 = CrossAttention(d_model, d_model, n_heads, d_head)
86
- self.norm1 = nn.LayerNorm(d_model)
87
- # Cross attention layer and pre-norm layer
88
- self.attn2 = CrossAttention(d_model, d_cond, n_heads, d_head)
89
- self.norm2 = nn.LayerNorm(d_model)
90
- # Feed-forward network and pre-norm layer
91
- self.ff = FeedForward(d_model)
92
- self.norm3 = nn.LayerNorm(d_model)
93
-
94
- def forward(self, x: torch.Tensor, cond: torch.Tensor):
95
- """
96
- :param x: are the input embeddings of shape `[batch_size, height * width, d_model]`
97
- :param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
98
- """
99
- # Self attention
100
- x = self.attn1(self.norm1(x)) + x
101
- # Cross-attention with conditioning
102
- x = self.attn2(self.norm2(x), cond=cond) + x
103
- # Feed-forward network
104
- x = self.ff(self.norm3(x)) + x
105
- #
106
- return x
107
-
108
-
109
- class CrossAttention(nn.Module):
110
- """
111
- ### Cross Attention Layer
112
-
113
- This falls-back to self-attention when conditional embeddings are not specified.
114
- """
115
-
116
- use_flash_attention: bool = False
117
-
118
- def __init__(
119
- self,
120
- d_model: int,
121
- d_cond: int,
122
- n_heads: int,
123
- d_head: int,
124
- is_inplace: bool = True,
125
- ):
126
- """
127
- :param d_model: is the input embedding size
128
- :param n_heads: is the number of attention heads
129
- :param d_head: is the size of a attention head
130
- :param d_cond: is the size of the conditional embeddings
131
- :param is_inplace: specifies whether to perform the attention softmax computation inplace to
132
- save memory
133
- """
134
- super().__init__()
135
-
136
- self.is_inplace = is_inplace
137
- self.n_heads = n_heads
138
- self.d_head = d_head
139
-
140
- # Attention scaling factor
141
- self.scale = d_head**-0.5
142
-
143
- # Query, key and value mappings
144
- d_attn = d_head * n_heads
145
- self.to_q = nn.Linear(d_model, d_attn, bias=False)
146
- self.to_k = nn.Linear(d_cond, d_attn, bias=False)
147
- self.to_v = nn.Linear(d_cond, d_attn, bias=False)
148
-
149
- # Final linear layer
150
- self.to_out = nn.Sequential(nn.Linear(d_attn, d_model))
151
-
152
- # Setup [flash attention](https://github.com/HazyResearch/flash-attention).
153
- # Flash attention is only used if it's installed
154
- # and `CrossAttention.use_flash_attention` is set to `True`.
155
- # try:
156
- # # You can install flash attention by cloning their Github repo,
157
- # # [https://github.com/HazyResearch/flash-attention](https://github.com/HazyResearch/flash-attention)
158
- # # and then running `python setup.py install`
159
- # from flash_attn.flash_attention import FlashAttention
160
-
161
- # self.flash = FlashAttention()
162
- # # Set the scale for scaled dot-product attention.
163
- # self.flash.softmax_scale = self.scale
164
- # # Set to `None` if it's not installed
165
- # except ImportError:
166
- # self.flash = None
167
-
168
- def forward(self, x: torch.Tensor, cond: Optional[torch.Tensor] = None):
169
- """
170
- :param x: are the input embeddings of shape `[batch_size, height * width, d_model]`
171
- :param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
172
- """
173
-
174
- # If `cond` is `None` we perform self attention
175
- has_cond = cond is not None
176
- if not has_cond:
177
- cond = x
178
-
179
- # Get query, key and value vectors
180
- q = self.to_q(x)
181
- k = self.to_k(cond)
182
- v = self.to_v(cond)
183
-
184
- # Use flash attention if it's available and the head size is less than or equal to `128`
185
- if (
186
- CrossAttention.use_flash_attention
187
- and self.flash is not None
188
- and not has_cond
189
- and self.d_head <= 128
190
- ):
191
- return self.flash_attention(q, k, v)
192
- # Otherwise, fallback to normal attention
193
- else:
194
- return self.normal_attention(q, k, v)
195
-
196
- def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
197
- """
198
- #### Flash Attention
199
-
200
- :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
201
- :param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
202
- :param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
203
- """
204
-
205
- # Get batch size and number of elements along sequence axis (`width * height`)
206
- batch_size, seq_len, _ = q.shape
207
-
208
- # Stack `q`, `k`, `v` vectors for flash attention, to get a single tensor of
209
- # shape `[batch_size, seq_len, 3, n_heads * d_head]`
210
- qkv = torch.stack((q, k, v), dim=2)
211
- # Split the heads
212
- qkv = qkv.view(batch_size, seq_len, 3, self.n_heads, self.d_head)
213
-
214
- # Flash attention works for head sizes `32`, `64` and `128`, so we have to pad the heads to
215
- # fit this size.
216
- if self.d_head <= 32:
217
- pad = 32 - self.d_head
218
- elif self.d_head <= 64:
219
- pad = 64 - self.d_head
220
- elif self.d_head <= 128:
221
- pad = 128 - self.d_head
222
- else:
223
- raise ValueError(f"Head size ${self.d_head} too large for Flash Attention")
224
-
225
- # Pad the heads
226
- if pad:
227
- qkv = torch.cat(
228
- (qkv, qkv.new_zeros(batch_size, seq_len, 3, self.n_heads, pad)), dim=-1
229
- )
230
-
231
- # Compute attention
232
- # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
233
- # This gives a tensor of shape `[batch_size, seq_len, n_heads, d_padded]`
234
- out, _ = self.flash(qkv)
235
- # Truncate the extra head size
236
- out = out[:, :, :, : self.d_head]
237
- # Reshape to `[batch_size, seq_len, n_heads * d_head]`
238
- out = out.reshape(batch_size, seq_len, self.n_heads * self.d_head)
239
-
240
- # Map to `[batch_size, height * width, d_model]` with a linear layer
241
- return self.to_out(out)
242
-
243
- def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
244
- """
245
- #### Normal Attention
246
-
247
- :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
248
- :param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
249
- :param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
250
- """
251
-
252
- # Split them to heads of shape `[batch_size, seq_len, n_heads, d_head]`
253
- q = q.view(*q.shape[:2], self.n_heads, -1)
254
- k = k.view(*k.shape[:2], self.n_heads, -1)
255
- v = v.view(*v.shape[:2], self.n_heads, -1)
256
-
257
- # Calculate attention $\frac{Q K^\top}{\sqrt{d_{key}}}$
258
- attn = torch.einsum("bihd,bjhd->bhij", q, k) * self.scale
259
-
260
- # Compute softmax
261
- # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$$
262
- if self.is_inplace:
263
- half = attn.shape[0] // 2
264
- attn[half:] = attn[half:].softmax(dim=-1)
265
- attn[:half] = attn[:half].softmax(dim=-1)
266
- else:
267
- attn = attn.softmax(dim=-1)
268
-
269
- # Compute attention output
270
- # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
271
- out = torch.einsum("bhij,bjhd->bihd", attn, v)
272
- # Reshape to `[batch_size, height * width, n_heads * d_head]`
273
- out = out.reshape(*out.shape[:2], -1)
274
- # Map to `[batch_size, height * width, d_model]` with a linear layer
275
- return self.to_out(out)
276
-
277
-
278
- class FeedForward(nn.Module):
279
- """
280
- ### Feed-Forward Network
281
- """
282
-
283
- def __init__(self, d_model: int, d_mult: int = 4):
284
- """
285
- :param d_model: is the input embedding size
286
- :param d_mult: is multiplicative factor for the hidden layer size
287
- """
288
- super().__init__()
289
- self.net = nn.Sequential(
290
- GeGLU(d_model, d_model * d_mult),
291
- nn.Dropout(0.0),
292
- nn.Linear(d_model * d_mult, d_model),
293
- )
294
-
295
- def forward(self, x: torch.Tensor):
296
- return self.net(x)
297
-
298
-
299
- class GeGLU(nn.Module):
300
- """
301
- ### GeGLU Activation
302
-
303
- $$\text{GeGLU}(x) = (xW + b) * \text{GELU}(xV + c)$$
304
- """
305
-
306
- def __init__(self, d_in: int, d_out: int):
307
- super().__init__()
308
- # Combined linear projections $xW + b$ and $xV + c$
309
- self.proj = nn.Linear(d_in, d_out * 2)
310
-
311
- def forward(self, x: torch.Tensor):
312
- # Get $xW + b$ and $xV + c$
313
- x, gate = self.proj(x).chunk(2, dim=-1)
314
- # $\text{GeGLU}(x) = (xW + b) * \text{GELU}(xV + c)$
315
- return x * F.gelu(gate)
swim/autoencoder.py DELETED
@@ -1,247 +0,0 @@
1
- from typing import List
2
-
3
- import torch
4
- import torch.nn.functional as F
5
- from torch import nn
6
-
7
- from .blocks import (
8
- ResnetBlock,
9
- AttentionBlock,
10
- GroupNorm,
11
- UpSampleBlock,
12
- DownSampleBlock,
13
- )
14
-
15
-
16
- class Autoencoder(nn.Module):
17
-
18
- def __init__(
19
- self,
20
- channels: int,
21
- channel_multipliers: List[int],
22
- n_resnet_blocks: int,
23
- in_channels: int,
24
- z_channels: int,
25
- emb_channels: int,
26
- ):
27
- super().__init__()
28
- self.encoder = Encoder(
29
- channels=channels,
30
- channel_multipliers=channel_multipliers,
31
- n_resnet_blocks=n_resnet_blocks,
32
- in_channels=in_channels,
33
- z_channels=z_channels,
34
- )
35
- self.decoder = Decoder(
36
- channels=channels,
37
- channel_multipliers=channel_multipliers,
38
- n_resnet_blocks=n_resnet_blocks,
39
- out_channels=in_channels,
40
- z_channels=z_channels,
41
- )
42
- # Convolution to map from embedding space to
43
- # quantized embedding space moments (mean and log variance)
44
- self.quant_conv = nn.Conv2d(2 * z_channels, 2 * emb_channels, 1)
45
- # Convolution to map from quantized embedding space back to
46
- # embedding space
47
- self.post_quant_conv = nn.Conv2d(emb_channels, z_channels, 1)
48
-
49
- def encode(self, img: torch.Tensor) -> "GaussianDistribution":
50
- # Get embeddings with shape `[batch_size, z_channels * 2, z_height, z_height]`
51
- z = self.encoder(img)
52
- # Get the moments in the quantized embedding space
53
- moments = self.quant_conv(z)
54
- # Return the distribution
55
- return GaussianDistribution(moments)
56
-
57
- def decode(self, z: torch.Tensor):
58
- # Map to embedding space from the quantized representation
59
- z = self.post_quant_conv(z)
60
- # Decode the image of shape `[batch_size, channels, height, width]`
61
- return self.decoder(z)
62
-
63
- def forward(self, x: torch.Tensor, sample_posterior: bool = False):
64
- posterior = self.encode(x)
65
- if sample_posterior:
66
- z = posterior.sample()
67
- else:
68
- z = posterior.mode()
69
- decoded_x = self.decode(z)
70
- return decoded_x, posterior
71
-
72
-
73
- class Encoder(nn.Module):
74
- def __init__(
75
- self,
76
- *,
77
- channels: int,
78
- channel_multipliers: List[int],
79
- n_resnet_blocks: int,
80
- in_channels: int,
81
- z_channels: int
82
- ):
83
- super().__init__()
84
-
85
- # Number of blocks of different resolutions.
86
- # The resolution is halved at the end each top level block
87
- n_resolutions = len(channel_multipliers)
88
-
89
- # Initial $3 \times 3$ convolution layer that maps the image to `channels`
90
- self.conv_in = nn.Conv2d(in_channels, channels, 3, stride=1, padding=1)
91
-
92
- # Number of channels in each top level block
93
- channels_list = [m * channels for m in [1] + channel_multipliers]
94
-
95
- # List of top-level blocks
96
- self.down = nn.ModuleList()
97
- # Create top-level blocks
98
- for i in range(n_resolutions):
99
- # Each top level block consists of multiple ResNet Blocks and down-sampling
100
- resnet_blocks = nn.ModuleList()
101
- # Add ResNet Blocks
102
- for _ in range(n_resnet_blocks):
103
- resnet_blocks.append(ResnetBlock(channels, channels_list[i + 1]))
104
- channels = channels_list[i + 1]
105
- # Top-level block
106
- down = nn.Module()
107
- down.block = resnet_blocks
108
- # Down-sampling at the end of each top level block except the last
109
- if i != n_resolutions - 1:
110
- down.downsample = DownSampleBlock(channels)
111
- else:
112
- down.downsample = nn.Identity()
113
- #
114
- self.down.append(down)
115
-
116
- # Final ResNet blocks with attention
117
- self.mid = nn.Module()
118
- self.mid.block_1 = ResnetBlock(channels, channels)
119
- self.mid.attn_1 = AttentionBlock(channels)
120
- self.mid.block_2 = ResnetBlock(channels, channels)
121
-
122
- # Map to embedding space with a $3 \times 3$ convolution
123
- self.norm_out = GroupNorm(channels)
124
- self.conv_out = nn.Conv2d(channels, 2 * z_channels, 3, stride=1, padding=1)
125
-
126
- def forward(self, img: torch.Tensor):
127
- # Map to `channels` with the initial convolution
128
- x = self.conv_in(img)
129
-
130
- # Top-level blocks
131
- for down in self.down:
132
- # ResNet Blocks
133
- for block in down.block:
134
- x = block(x)
135
- # Down-sampling
136
- x = down.downsample(x)
137
-
138
- # Final ResNet blocks with attention
139
- x = self.mid.block_1(x)
140
- x = self.mid.attn_1(x)
141
- x = self.mid.block_2(x)
142
-
143
- # Normalize and map to embedding space
144
- x = self.norm_out(x)
145
- x = F.silu(x)
146
- x = self.conv_out(x)
147
-
148
- return x
149
-
150
-
151
- class Decoder(nn.Module):
152
-
153
- def __init__(
154
- self,
155
- *,
156
- channels: int,
157
- channel_multipliers: List[int],
158
- n_resnet_blocks: int,
159
- out_channels: int,
160
- z_channels: int
161
- ):
162
- super().__init__()
163
-
164
- # Number of blocks of different resolutions.
165
- # The resolution is halved at the end each top level block
166
- num_resolutions = len(channel_multipliers)
167
-
168
- # Number of channels in each top level block, in the reverse order
169
- channels_list = [m * channels for m in channel_multipliers]
170
-
171
- # Number of channels in the top-level block
172
- channels = channels_list[-1]
173
-
174
- # Initial $3 \times 3$ convolution layer that maps the embedding space to `channels`
175
- self.conv_in = nn.Conv2d(z_channels, channels, 3, stride=1, padding=1)
176
-
177
- # ResNet blocks with attention
178
- self.mid = nn.Module()
179
- self.mid.block_1 = ResnetBlock(channels, channels)
180
- self.mid.attn_1 = AttentionBlock(channels)
181
- self.mid.block_2 = ResnetBlock(channels, channels)
182
-
183
- # List of top-level blocks
184
- self.up = nn.ModuleList()
185
- # Create top-level blocks
186
- for i in reversed(range(num_resolutions)):
187
- # Each top level block consists of multiple ResNet Blocks and up-sampling
188
- resnet_blocks = nn.ModuleList()
189
- # Add ResNet Blocks
190
- for _ in range(n_resnet_blocks + 1):
191
- resnet_blocks.append(ResnetBlock(channels, channels_list[i]))
192
- channels = channels_list[i]
193
- # Top-level block
194
- up = nn.Module()
195
- up.block = resnet_blocks
196
- # Up-sampling at the end of each top level block except the first
197
- if i != 0:
198
- up.upsample = UpSampleBlock(channels)
199
- else:
200
- up.upsample = nn.Identity()
201
- # Prepend to be consistent with the checkpoint
202
- self.up.insert(0, up)
203
-
204
- # Map to image space with a $3 \times 3$ convolution
205
- self.norm_out = GroupNorm(channels)
206
- self.conv_out = nn.Conv2d(channels, out_channels, 3, stride=1, padding=1)
207
-
208
- def forward(self, z: torch.Tensor):
209
- # Map to `channels` with the initial convolution
210
- h = self.conv_in(z)
211
-
212
- # ResNet blocks with attention
213
- h = self.mid.block_1(h)
214
- h = self.mid.attn_1(h)
215
- h = self.mid.block_2(h)
216
-
217
- # Top-level blocks
218
- for up in reversed(self.up):
219
- # ResNet Blocks
220
- for block in up.block:
221
- h = block(h)
222
- # Up-sampling
223
- h = up.upsample(h)
224
-
225
- # Normalize and map to image space
226
- h = self.norm_out(h)
227
- h = F.silu(h)
228
- img = self.conv_out(h)
229
-
230
- return img
231
-
232
-
233
- class GaussianDistribution:
234
- def __init__(self, parameters: torch.Tensor):
235
- # Split mean and log of variance
236
- self.mean, log_var = torch.chunk(parameters, 2, dim=1)
237
- # Clamp the log of variances
238
- self.log_var = torch.clamp(log_var, -30.0, 20.0)
239
- # Calculate standard deviation
240
- self.std = torch.exp(0.5 * self.log_var)
241
-
242
- def sample(self):
243
- # Sample from the distribution
244
- return self.mean + self.std * torch.randn_like(self.std)
245
-
246
- def mode(self):
247
- return self.mean
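The GaussianDistribution deleted here is superseded by DiagonalGaussianDistribution from swim.modules.distributions, which the new swim/models/autoencoder.py imports. For context, the train/kl_loss in metrics.csv is conventionally the closed-form KL between such a diagonal Gaussian and a standard normal; a sketch of that formula (standard math, not code from this repo):

```python
import torch

def kl_to_standard_normal(mean: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
    # KL( N(mean, exp(log_var)) || N(0, I) ), summed over the latent dimensions
    return 0.5 * torch.sum(mean.pow(2) + log_var.exp() - 1.0 - log_var, dim=[1, 2, 3])
```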
swim/blocks.py DELETED
@@ -1,227 +0,0 @@
1
- from abc import abstractmethod
2
-
3
- import math
4
- import torch
5
- from torch import nn
6
- from torch.nn import functional as F
7
-
8
-
9
- def get_timestep_embedding(
10
- timesteps: torch.Tensor, emb_dim: int, max_period: int = 10000
11
- ) -> torch.Tensor:
12
- half_dim = emb_dim // 2
13
-
14
- emb = math.log(max_period) / (half_dim - 1)
15
- emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
16
- emb = emb.to(device=timesteps.device)
17
-
18
- emb = timesteps.float()[:, None] * emb[None, :]
19
- emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
20
-
21
- if emb_dim % 2 == 1:
22
- emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
23
-
24
- return emb
25
-
26
-
27
- class GroupNorm(nn.Module):
28
- def __init__(self, in_channels: int) -> None:
29
- super().__init__()
30
-
31
- self.group_norm = nn.GroupNorm(
32
- num_groups=32, num_channels=in_channels, eps=1e-06, affine=True
33
- )
34
-
35
- def forward(self, x: torch.Tensor) -> torch.Tensor:
36
- return self.group_norm(x)
37
-
38
-
39
- class UpSampleBlock(nn.Module):
40
- def __init__(self, channels: int):
41
- super().__init__()
42
- self.conv = nn.Conv2d(channels, channels, 3, padding=1)
43
-
44
- def forward(self, x: torch.Tensor):
45
- x = F.interpolate(x, scale_factor=2.0, mode="nearest")
46
- return self.conv(x)
47
-
48
-
49
- class DownSampleBlock(nn.Module):
50
- def __init__(self, channels: int):
51
- super().__init__()
52
- self.conv = nn.Conv2d(channels, channels, 3, stride=2, padding=0)
53
-
54
- def forward(self, x: torch.Tensor):
55
- x = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)
56
- return self.conv(x)
57
-
58
-
59
- class TimestepBlock(nn.Module):
60
- @abstractmethod
61
- def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
62
- pass
63
-
64
-
65
- class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
66
- def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
67
- for layer in self:
68
- if isinstance(layer, TimestepBlock):
69
- x = layer(x, t_emb)
70
- else:
71
- x = layer(x)
72
- return x
73
-
74
-
75
- class ResnetBlock(nn.Module):
76
-
77
- def __init__(
78
- self,
79
- in_channels: int,
80
- out_channels: int = None,
81
- t_emb_dim: int = None,
82
- dropout: float = 0.0,
83
- ):
84
- super().__init__()
85
-
86
- if out_channels is None:
87
- out_channels = in_channels
88
-
89
- self.input_layers = nn.Sequential(
90
- GroupNorm(in_channels),
91
- nn.SiLU(),
92
- nn.Conv2d(in_channels, out_channels, 3, padding=1),
93
- )
94
-
95
- if t_emb_dim is not None:
96
- self.t_emb_layers = nn.Sequential(
97
- nn.SiLU(),
98
- nn.Linear(t_emb_dim, out_channels),
99
- )
100
- else:
101
- self.t_emb_layers = None
102
-
103
- self.output_layers = nn.Sequential(
104
- GroupNorm(out_channels),
105
- nn.SiLU(),
106
- nn.Dropout(dropout),
107
- nn.Conv2d(out_channels, out_channels, 3, padding=1),
108
- )
109
-
110
- if in_channels != out_channels:
111
- self.skip = nn.Conv2d(in_channels, out_channels, 1)
112
- else:
113
- self.skip = nn.Identity()
114
-
115
- def forward(self, x: torch.Tensor, t: torch.Tensor = None) -> torch.Tensor:
116
- assert t is not None or self.t_emb_layers is None
117
-
118
- h = self.input_layers(x)
119
-
120
- if self.t_emb_layers is not None:
121
- t_emb = self.t_emb_layers(t)
122
- h = h + t_emb[:, :, None, None]
123
-
124
- h = self.output_layers(h)
125
-
126
- h = h + self.skip(x)
127
-
128
- return h
129
-
130
-
131
- class AttentionBlock(nn.Module):
132
- def __init__(self, in_channels: int) -> None:
133
- super().__init__()
134
-
135
- self.in_channels = in_channels
136
-
137
- # normalization layer
138
- self.norm = GroupNorm(in_channels)
139
-
140
- # query, key and value layers
141
- self.q = nn.Conv2d(in_channels, in_channels, 1, 1, 0)
142
- self.k = nn.Conv2d(in_channels, in_channels, 1, 1, 0)
143
- self.v = nn.Conv2d(in_channels, in_channels, 1, 1, 0)
144
-
145
- self.project_out = nn.Conv2d(in_channels, in_channels, 1, 1, 0)
146
-
147
- self.softmax = nn.Softmax(dim=2)
148
-
149
- def forward(self, x):
150
-
151
- batch, _, height, width = x.size()
152
-
153
- x = self.norm(x)
154
-
155
- # query, key and value layers
156
- q = self.q(x)
157
- k = self.k(x)
158
- v = self.v(x)
159
-
160
- # resizing the output from 4D to 3D to generate attention map
161
- q = q.reshape(batch, self.in_channels, height * width)
162
- k = k.reshape(batch, self.in_channels, height * width)
163
- v = v.reshape(batch, self.in_channels, height * width)
164
-
165
- # transpose the query tensor for dot product
166
- q = q.permute(0, 2, 1)
167
-
168
- # main attention formula
169
- scores = torch.bmm(q, k) * (self.in_channels**-0.5)
170
- weights = self.softmax(scores)
171
- weights = weights.permute(0, 2, 1)
172
-
173
- attention = torch.bmm(v, weights)
174
-
175
- # resizing the output from 3D to 4D to match the input
176
- attention = attention.reshape(batch, self.in_channels, height, width)
177
- attention = self.project_out(attention)
178
-
179
- # adding the identity to the output
180
- return x + attention
181
-
182
-
183
- class AttentionBlock(nn.Module):
184
- def __init__(self, channels: int):
185
- super().__init__()
186
- # Group normalization
187
- self.norm = GroupNorm(channels)
188
- # Query, key and value mappings
189
- self.q = nn.Conv2d(channels, channels, 1)
190
- self.k = nn.Conv2d(channels, channels, 1)
191
- self.v = nn.Conv2d(channels, channels, 1)
192
-
193
- self.proj_out = nn.Conv2d(channels, channels, 1)
194
-
195
- # Attention scaling factor
196
- self.scale = channels**-0.5
197
-
198
- def forward(self, x: torch.Tensor):
199
- # Normalize `x`
200
- x_norm = self.norm(x)
201
- # Get query, key and vector embeddings
202
- q = self.q(x_norm)
203
- k = self.k(x_norm)
204
- v = self.v(x_norm)
205
-
206
- # Reshape the query, key and value embeddings from
207
- # `[batch_size, channels, height, width]` to
208
- # `[batch_size, channels, height * width]`
209
- b, c, h, w = q.shape
210
- q = q.view(b, c, h * w)
211
- k = k.view(b, c, h * w)
212
- v = v.view(b, c, h * w)
213
-
214
- # Compute $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$
215
- attn = torch.einsum("bci,bcj->bij", q, k) * self.scale
216
- attn = F.softmax(attn, dim=2)
217
-
218
- # Compute $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$
219
- out = torch.einsum("bij,bcj->bci", attn, v)
220
-
221
- # Reshape back to `[batch_size, channels, height, width]`
222
- out = out.view(b, c, h, w)
223
- # Final $1 \times 1$ convolution layer
224
- out = self.proj_out(out)
225
-
226
- # Add residual connection
227
- return x + out
swim/codeblock.py DELETED
@@ -1,74 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
-
4
-
5
- class SwimCodeBook(nn.Module):
6
- def __init__(
7
- self, num_codebook_vectors: int = 1024, latent_dim: int = 256, beta: int = 0.25
8
- ):
9
- super().__init__()
10
-
11
- self.num_codebook_vectors = num_codebook_vectors
12
- self.latent_dim = latent_dim
13
- self.beta = beta
14
-
15
- # creating the codebook, nn.Embedding here is simply a 2D array mainly for storing our embeddings, it's also learnable
16
- self.codebook = nn.Embedding(num_codebook_vectors, latent_dim)
17
-
18
- # Initializing the weights in codebook in uniform distribution
19
- self.codebook.weight.data.uniform_(
20
- -1 / num_codebook_vectors, 1 / num_codebook_vectors
21
- )
22
-
23
- def forward(self, z: torch.Tensor) -> torch.Tensor:
24
- # Channel to last dimension and copying the tensor to store it in a contiguous ( in a sequence ) way
25
- z = z.permute(0, 2, 3, 1).contiguous()
26
-
27
- z_flattened = z.view(
28
- -1, self.latent_dim
29
- ) # b*h*w * latent_dim, will look similar to codebook in fig 2 of the paper
30
-
31
- # calculating the distance between the z to the vectors in flattened codebook, from eq. 2
32
- # (a - b)^2 = a^2 + b^2 - 2ab
33
- distance = (
34
- torch.sum(
35
- z_flattened**2, dim=1, keepdim=True
36
- ) # keepdim = True to keep the same original shape after the sum
37
- + torch.sum(self.codebook.weight**2, dim=1)
38
- - 2
39
- * torch.matmul(
40
- z_flattened, self.codebook.weight.t()
41
- ) # 2*dot(z, codebook.T)
42
- )
43
-
44
- # getting indices of vectors with minimum distance from the codebook
45
- min_distance_indices = torch.argmin(distance, dim=1)
46
-
47
- # getting the corresponding vector from the codebook
48
- z_q = self.codebook(min_distance_indices).view(z.shape)
49
-
50
- """
51
- this represents equation 4 from the paper (except the reconstruction loss). This loss will then be added
52
- to GAN loss to create the final loss function for VQGAN, eq. 6 in the paper.
53
-
54
-
55
- Note : In the first para of A. Changlog section of the paper,
56
- they found a bug which resulted in beta equal to 1. here https://github.com/CompVis/taming-transformers/issues/57
57
- just a note :)
58
- """
59
- loss = torch.mean(
60
- (z_q.detach() - z) ** 2
61
- # detach() to avoid calculating gradient while backpropagating
62
- + self.beta
63
- * torch.mean(
64
- (z_q - z.detach()) ** 2
65
- ) # commitment loss, detach() to avoid calculating gradient while backpropagating
66
- )
67
-
68
- # Not sure why we need this, but it's in the original implementation and mentions for "preserving gradients"
69
- z_q = z + (z_q - z).detach()
70
-
71
- # reshaping to the original shape
72
- z_q = z_q.permute(0, 3, 1, 2)
73
-
74
- return z_q, min_distance_indices, loss
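On the "preserving gradients" comment above: z_q = z + (z_q - z).detach() is the straight-through estimator. The forward value equals z_q, but gradients flow back into z (the encoder output) as if quantization were the identity. A tiny standalone check:

```python
import torch

z = torch.randn(4, requires_grad=True)
z_q = torch.round(z)               # stand-in for the non-differentiable codebook lookup
z_st = z + (z_q - z).detach()      # forward: z_q, backward: identity w.r.t. z

z_st.sum().backward()
print(torch.allclose(z_st, z_q))   # True
print(z.grad)                      # tensor([1., 1., 1., 1.])
```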
swim/discriminator.py DELETED
@@ -1,45 +0,0 @@
1
- import torch.nn as nn
2
-
3
-
4
- class Discriminator(nn.Module):
5
- """PatchGAN Discriminator
6
-
7
-
8
- Args:
9
- image_channels (int): Number of channels in the input image.
10
- num_filters_last (int): Number of filters in the last layer of the discriminator.
11
- n_layers (int): Number of layers in the discriminator.
12
-
13
-
14
- """
15
-
16
- def __init__(self, image_channels: int = 3, num_filters_last=64, n_layers=3):
17
- super(Discriminator, self).__init__()
18
-
19
- layers = [
20
- nn.Conv2d(image_channels, num_filters_last, 4, 2, 1),
21
- nn.LeakyReLU(0.2),
22
- ]
23
- num_filters_mult = 1
24
-
25
- for i in range(1, n_layers + 1):
26
- num_filters_mult_last = num_filters_mult
27
- num_filters_mult = min(2**i, 8)
28
- layers += [
29
- nn.Conv2d(
30
- num_filters_last * num_filters_mult_last,
31
- num_filters_last * num_filters_mult,
32
- 4,
33
- 2 if i < n_layers else 1,
34
- 1,
35
- bias=False,
36
- ),
37
- nn.BatchNorm2d(num_filters_last * num_filters_mult),
38
- nn.LeakyReLU(0.2, True),
39
- ]
40
-
41
- layers.append(nn.Conv2d(num_filters_last * num_filters_mult, 1, 4, 1, 1))
42
- self.model = nn.Sequential(*layers)
43
-
44
- def forward(self, x):
45
- return self.model(x)
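As a PatchGAN, this discriminator returns a grid of per-patch logits rather than a single scalar per image (the new LPIPSWithDiscriminator loss is not shown in this commit, so whether it wraps an equivalent network is an assumption). A quick shape check, with the class definition above in scope:

```python
import torch

disc = Discriminator(image_channels=3, num_filters_last=64, n_layers=3)
logits = disc(torch.randn(2, 3, 256, 256))
print(logits.shape)  # torch.Size([2, 1, 30, 30]): one real/fake logit per image patch
```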
swim/lr_scheduler.py ADDED
@@ -0,0 +1,98 @@
1
+ import numpy as np
2
+
3
+
4
+ class LambdaWarmUpCosineScheduler:
5
+ """
6
+ note: use with a base_lr of 1.0
7
+ """
8
+ def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
9
+ self.lr_warm_up_steps = warm_up_steps
10
+ self.lr_start = lr_start
11
+ self.lr_min = lr_min
12
+ self.lr_max = lr_max
13
+ self.lr_max_decay_steps = max_decay_steps
14
+ self.last_lr = 0.
15
+ self.verbosity_interval = verbosity_interval
16
+
17
+ def schedule(self, n, **kwargs):
18
+ if self.verbosity_interval > 0:
19
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
20
+ if n < self.lr_warm_up_steps:
21
+ lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
22
+ self.last_lr = lr
23
+ return lr
24
+ else:
25
+ t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
26
+ t = min(t, 1.0)
27
+ lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
28
+ 1 + np.cos(t * np.pi))
29
+ self.last_lr = lr
30
+ return lr
31
+
32
+ def __call__(self, n, **kwargs):
33
+ return self.schedule(n,**kwargs)
34
+
35
+
36
+ class LambdaWarmUpCosineScheduler2:
37
+ """
38
+ supports repeated iterations, configurable via lists
39
+ note: use with a base_lr of 1.0.
40
+ """
41
+ def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
42
+ assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
43
+ self.lr_warm_up_steps = warm_up_steps
44
+ self.f_start = f_start
45
+ self.f_min = f_min
46
+ self.f_max = f_max
47
+ self.cycle_lengths = cycle_lengths
48
+ self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
49
+ self.last_f = 0.
50
+ self.verbosity_interval = verbosity_interval
51
+
52
+ def find_in_interval(self, n):
53
+ interval = 0
54
+ for cl in self.cum_cycles[1:]:
55
+ if n <= cl:
56
+ return interval
57
+ interval += 1
58
+
59
+ def schedule(self, n, **kwargs):
60
+ cycle = self.find_in_interval(n)
61
+ n = n - self.cum_cycles[cycle]
62
+ if self.verbosity_interval > 0:
63
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
64
+ f"current cycle {cycle}")
65
+ if n < self.lr_warm_up_steps[cycle]:
66
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
67
+ self.last_f = f
68
+ return f
69
+ else:
70
+ t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
71
+ t = min(t, 1.0)
72
+ f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
73
+ 1 + np.cos(t * np.pi))
74
+ self.last_f = f
75
+ return f
76
+
77
+ def __call__(self, n, **kwargs):
78
+ return self.schedule(n, **kwargs)
79
+
80
+
81
+ class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
82
+
83
+ def schedule(self, n, **kwargs):
84
+ cycle = self.find_in_interval(n)
85
+ n = n - self.cum_cycles[cycle]
86
+ if self.verbosity_interval > 0:
87
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
88
+ f"current cycle {cycle}")
89
+
90
+ if n < self.lr_warm_up_steps[cycle]:
91
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
92
+ self.last_f = f
93
+ return f
94
+ else:
95
+ f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
96
+ self.last_f = f
97
+ return f
98
+
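These schedules return a multiplier for torch.optim.lr_scheduler.LambdaLR; per the docstrings the optimizer's base lr should be 1.0 so the returned value is the effective learning rate. A usage sketch (all hyperparameter values below are illustrative):

```python
import torch
from swim.lr_scheduler import LambdaWarmUpCosineScheduler

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.Adam([param], lr=1.0)  # base lr of 1.0, as the docstring advises

schedule = LambdaWarmUpCosineScheduler(
    warm_up_steps=1000, lr_min=1e-6, lr_max=4.5e-6, lr_start=1e-8, max_decay_steps=50000
)
lr_sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=schedule)

for step in range(5):
    opt.step()
    lr_sched.step()  # the effective lr is base_lr (1.0) times the value returned by `schedule`
```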
swim/models/autoencoder.py ADDED
@@ -0,0 +1,219 @@
1
+ import torch
2
+ from lightning import LightningModule
3
+ import torch.nn.functional as F
4
+ from contextlib import contextmanager
5
+
6
+ from swim.modules.diffusionmodules.model import Encoder, Decoder
7
+ from swim.modules.distributions.distributions import DiagonalGaussianDistribution
8
+
9
+ from swim.utils import instantiate_from_config
10
+
11
+
12
+ class AutoencoderKL(LightningModule):
13
+ def __init__(
14
+ self,
15
+ ddconfig,
16
+ lossconfig,
17
+ embed_dim,
18
+ ckpt_path=None,
19
+ ignore_keys=[],
20
+ monitor=None,
21
+ ):
22
+ super().__init__()
23
+ self.encoder = Encoder(**ddconfig)
24
+ self.decoder = Decoder(**ddconfig)
25
+ self.loss = instantiate_from_config(lossconfig)
26
+ assert ddconfig["double_z"]
27
+ self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
28
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
29
+ self.embed_dim = embed_dim
30
+
31
+ self.automatic_optimization = False
32
+
33
+ if monitor is not None:
34
+ self.monitor = monitor
35
+
36
+ if ckpt_path is not None:
37
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
38
+
39
+ def init_from_ckpt(self, path, ignore_keys=list()):
40
+ sd = torch.load(path, map_location="cpu")["state_dict"]
41
+ keys = list(sd.keys())
42
+ for k in keys:
43
+ for ik in ignore_keys:
44
+ if k.startswith(ik):
45
+ print("Deleting key {} from state_dict.".format(k))
46
+ del sd[k]
47
+ self.load_state_dict(sd, strict=False)
48
+ print(f"Restored from {path}")
49
+
50
+ def encode(self, x):
51
+ h = self.encoder(x)
52
+ moments = self.quant_conv(h)
53
+ posterior = DiagonalGaussianDistribution(moments)
54
+ return posterior
55
+
56
+ def decode(self, z):
57
+ z = self.post_quant_conv(z)
58
+ dec = self.decoder(z)
59
+ return dec
60
+
61
+ def forward(self, input, sample_posterior=True):
62
+ posterior = self.encode(input)
63
+ if sample_posterior:
64
+ z = posterior.sample()
65
+ else:
66
+ z = posterior.mode()
67
+ dec = self.decode(z)
68
+ return dec, posterior
69
+
70
+ def get_input(self, batch, k):
71
+ x = batch[k]
72
+ if len(x.shape) == 3:
73
+ x = x[..., None]
74
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
75
+ return x
76
+
77
+ def training_step(self, batch, batch_idx):
78
+ opt_ae, opt_disc = self.optimizers()
79
+
80
+ # optimize the autoencoder
81
+ reconstructions, posterior = self(batch["images"])
82
+
83
+ ae_loss, log_dict_ae = self.loss(
84
+ batch["images"],
85
+ reconstructions,
86
+ posterior,
87
+ 0,
88
+ self.global_step,
89
+ last_layer=self.get_last_layer(),
90
+ split="train",
91
+ )
92
+
93
+ opt_ae.zero_grad()
94
+ self.manual_backward(ae_loss)
95
+ opt_ae.step()
96
+
97
+ self.log(
98
+ "aeloss",
99
+ ae_loss,
100
+ prog_bar=True,
101
+ logger=True,
102
+ on_step=True,
103
+ on_epoch=True,
104
+ )
105
+
106
+ self.log_dict(
107
+ log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False
108
+ )
109
+
110
+ # optimize the discriminator
111
+ reconstructions, posterior = self(batch["images"])
112
+
113
+ disc_loss, log_dict_disc = self.loss(
114
+ batch["images"],
115
+ reconstructions,
116
+ posterior,
117
+ 1,
118
+ self.global_step,
119
+ last_layer=self.get_last_layer(),
120
+ split="train",
121
+ )
122
+
123
+ opt_disc.zero_grad()
124
+ self.manual_backward(disc_loss)
125
+ opt_disc.step()
126
+
127
+ self.log(
128
+ "discloss",
129
+ disc_loss,
130
+ prog_bar=True,
131
+ logger=True,
132
+ on_step=True,
133
+ on_epoch=True,
134
+ )
135
+
136
+ self.log_dict(
137
+ log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False
138
+ )
139
+
140
+ def validation_step(self, batch, batch_idx):
141
+ reconstructions, posterior = self(batch["images"])
142
+ aeloss, log_dict_ae = self.loss(
143
+ batch["images"],
144
+ reconstructions,
145
+ posterior,
146
+ 0,
147
+ self.global_step,
148
+ last_layer=self.get_last_layer(),
149
+ split="val",
150
+ )
151
+
152
+ discloss, log_dict_disc = self.loss(
153
+ batch["images"],
154
+ reconstructions,
155
+ posterior,
156
+ 1,
157
+ self.global_step,
158
+ last_layer=self.get_last_layer(),
159
+ split="val",
160
+ )
161
+
162
+ self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
163
+ self.log_dict(log_dict_ae)
164
+ self.log_dict(log_dict_disc)
165
+
166
+ def configure_optimizers(self):
167
+ lr = self.learning_rate
168
+ opt_ae = torch.optim.Adam(
169
+ list(self.encoder.parameters())
170
+ + list(self.decoder.parameters())
171
+ + list(self.quant_conv.parameters())
172
+ + list(self.post_quant_conv.parameters()),
173
+ lr=lr,
174
+ betas=(0.5, 0.9),
175
+ )
176
+ opt_disc = torch.optim.Adam(
177
+ self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)
178
+ )
179
+ return [opt_ae, opt_disc], []
180
+
181
+ def get_last_layer(self):
182
+ return self.decoder.conv_out.weight
183
+
184
+ @torch.no_grad()
185
+ def log_images(self, batch, only_inputs=False, **kwargs):
186
+ log = dict()
187
+ x = batch["images"]
188
+ x = x.to(self.device)
189
+ if not only_inputs:
190
+ xrec, posterior = self(x)
191
+ if x.shape[1] > 3:
192
+ # colorize with random projection
193
+ assert xrec.shape[1] > 3
194
+ x = self.to_rgb(x)
195
+ xrec = self.to_rgb(xrec)
196
+ log["samples"] = self.decode(torch.randn_like(posterior.sample()))
197
+ log["reconstructions"] = xrec
198
+ log["inputs"] = x
199
+ return log
200
+
201
+
202
+ class IdentityFirstStage(torch.nn.Module):
203
+ def __init__(self, *args, vq_interface=False, **kwargs):
204
+ self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
205
+ super().__init__()
206
+
207
+ def encode(self, x, *args, **kwargs):
208
+ return x
209
+
210
+ def decode(self, x, *args, **kwargs):
211
+ return x
212
+
213
+ def quantize(self, x, *args, **kwargs):
214
+ if self.vq_interface:
215
+ return x, None, [None, None, None]
216
+ return x
217
+
218
+ def forward(self, x, *args, **kwargs):
219
+ return x
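A sketch of the encode/decode round trip that forward() wraps. This assumes a model instantiated from the YAML config above (which in turn needs swim.modules.losses.LPIPSWithDiscriminator and the diffusionmodules Encoder/Decoder to be importable); the shapes follow from embed_dim=4 and the factor-8 downsampling of ch_mult [1, 2, 4, 4]:

```python
import torch

model.eval()                      # `model` built as in main.py
x = torch.randn(1, 3, 64, 64)     # any resolution divisible by 8
with torch.no_grad():
    posterior = model.encode(x)   # DiagonalGaussianDistribution over the latent
    z = posterior.mode()          # [1, 4, 8, 8]
    x_rec = model.decode(z)       # [1, 3, 64, 64]
```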
swim/modules/attention.py ADDED
@@ -0,0 +1,277 @@
1
+ from inspect import isfunction
2
+ import math
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn, einsum
6
+ from einops import rearrange, repeat
7
+
8
+ from swim.modules.diffusionmodules.util import checkpoint
9
+
10
+
11
+ def exists(val):
12
+ return val is not None
13
+
14
+
15
+ def uniq(arr):
16
+ return {el: True for el in arr}.keys()
17
+
18
+
19
+ def default(val, d):
20
+ if exists(val):
21
+ return val
22
+ return d() if isfunction(d) else d
23
+
24
+
25
+ def max_neg_value(t):
26
+ return -torch.finfo(t.dtype).max
27
+
28
+
29
+ def init_(tensor):
30
+ dim = tensor.shape[-1]
31
+ std = 1 / math.sqrt(dim)
32
+ tensor.uniform_(-std, std)
33
+ return tensor
34
+
35
+
36
+ # feedforward
37
+ class GEGLU(nn.Module):
38
+ def __init__(self, dim_in, dim_out):
39
+ super().__init__()
40
+ self.proj = nn.Linear(dim_in, dim_out * 2)
41
+
42
+ def forward(self, x):
43
+ x, gate = self.proj(x).chunk(2, dim=-1)
44
+ return x * F.gelu(gate)
45
+
46
+
47
+ class FeedForward(nn.Module):
48
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
49
+ super().__init__()
50
+ inner_dim = int(dim * mult)
51
+ dim_out = default(dim_out, dim)
52
+ project_in = (
53
+ nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
54
+ if not glu
55
+ else GEGLU(dim, inner_dim)
56
+ )
57
+
58
+ self.net = nn.Sequential(
59
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
60
+ )
61
+
62
+ def forward(self, x):
63
+ return self.net(x)
64
+
65
+
66
+ def zero_module(module):
67
+ """
68
+ Zero out the parameters of a module and return it.
69
+ """
70
+ for p in module.parameters():
71
+ p.detach().zero_()
72
+ return module
73
+
74
+
75
+ def Normalize(in_channels):
76
+ return torch.nn.GroupNorm(
77
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
78
+ )
79
+
80
+
81
+ class LinearAttention(nn.Module):
82
+ def __init__(self, dim, heads=4, dim_head=32):
83
+ super().__init__()
84
+ self.heads = heads
85
+ hidden_dim = dim_head * heads
86
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
87
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
88
+
89
+ def forward(self, x):
90
+ b, c, h, w = x.shape
91
+ qkv = self.to_qkv(x)
92
+ q, k, v = rearrange(
93
+ qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
94
+ )
95
+ k = k.softmax(dim=-1)
96
+ context = torch.einsum("bhdn,bhen->bhde", k, v)
97
+ out = torch.einsum("bhde,bhdn->bhen", context, q)
98
+ out = rearrange(
99
+ out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
100
+ )
101
+ return self.to_out(out)
102
+
103
+
104
+ class SpatialSelfAttention(nn.Module):
105
+ def __init__(self, in_channels):
106
+ super().__init__()
107
+ self.in_channels = in_channels
108
+
109
+ self.norm = Normalize(in_channels)
110
+ self.q = torch.nn.Conv2d(
111
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
112
+ )
113
+ self.k = torch.nn.Conv2d(
114
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
115
+ )
116
+ self.v = torch.nn.Conv2d(
117
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
118
+ )
119
+ self.proj_out = torch.nn.Conv2d(
120
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
121
+ )
122
+
123
+ def forward(self, x):
124
+ h_ = x
125
+ h_ = self.norm(h_)
126
+ q = self.q(h_)
127
+ k = self.k(h_)
128
+ v = self.v(h_)
129
+
130
+ # compute attention
131
+ b, c, h, w = q.shape
132
+ q = rearrange(q, "b c h w -> b (h w) c")
133
+ k = rearrange(k, "b c h w -> b c (h w)")
134
+ w_ = torch.einsum("bij,bjk->bik", q, k)
135
+
136
+ w_ = w_ * (int(c) ** (-0.5))
137
+ w_ = torch.nn.functional.softmax(w_, dim=2)
138
+
139
+ # attend to values
140
+ v = rearrange(v, "b c h w -> b c (h w)")
141
+ w_ = rearrange(w_, "b i j -> b j i")
142
+ h_ = torch.einsum("bij,bjk->bik", v, w_)
143
+ h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
144
+ h_ = self.proj_out(h_)
145
+
146
+ return x + h_
147
+
148
+
149
+ class CrossAttention(nn.Module):
150
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
151
+ super().__init__()
152
+ inner_dim = dim_head * heads
153
+ context_dim = default(context_dim, query_dim)
154
+
155
+ self.scale = dim_head**-0.5
156
+ self.heads = heads
157
+
158
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
159
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
160
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
161
+
162
+ self.to_out = nn.Sequential(
163
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
164
+ )
165
+
166
+ def forward(self, x, context=None, mask=None):
167
+ h = self.heads
168
+
169
+ q = self.to_q(x)
170
+ context = default(context, x)
171
+ k = self.to_k(context)
172
+ v = self.to_v(context)
173
+
174
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
175
+
176
+ sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
177
+
178
+ if exists(mask):
179
+ mask = rearrange(mask, "b ... -> b (...)")
180
+ max_neg_value = -torch.finfo(sim.dtype).max
181
+ mask = repeat(mask, "b j -> (b h) () j", h=h)
182
+ sim.masked_fill_(~mask, max_neg_value)
183
+
184
+ # attention, what we cannot get enough of
185
+ attn = sim.softmax(dim=-1)
186
+
187
+ out = einsum("b i j, b j d -> b i d", attn, v)
188
+ out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
189
+ return self.to_out(out)
190
+
191
+
192
+ class BasicTransformerBlock(nn.Module):
193
+ def __init__(
194
+ self,
195
+ dim,
196
+ n_heads,
197
+ d_head,
198
+ dropout=0.0,
199
+ context_dim=None,
200
+ gated_ff=True,
201
+ checkpoint=True,
202
+ ):
203
+ super().__init__()
204
+ self.attn1 = CrossAttention(
205
+ query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
206
+ ) # is a self-attention
207
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
208
+ self.attn2 = CrossAttention(
209
+ query_dim=dim,
210
+ context_dim=context_dim,
211
+ heads=n_heads,
212
+ dim_head=d_head,
213
+ dropout=dropout,
214
+ ) # is self-attn if context is none
215
+ self.norm1 = nn.LayerNorm(dim)
216
+ self.norm2 = nn.LayerNorm(dim)
217
+ self.norm3 = nn.LayerNorm(dim)
218
+ self.checkpoint = checkpoint
219
+
220
+ def forward(self, x, context=None):
221
+ return checkpoint(
222
+ self._forward, (x, context), self.parameters(), self.checkpoint
223
+ )
224
+
225
+ def _forward(self, x, context=None):
226
+ x = self.attn1(self.norm1(x)) + x
227
+ x = self.attn2(self.norm2(x), context=context) + x
228
+ x = self.ff(self.norm3(x)) + x
229
+ return x
230
+
231
+
232
+ class SpatialTransformer(nn.Module):
233
+ """
234
+ Transformer block for image-like data.
235
+ First, project the input (aka embedding)
236
+ and reshape to b, t, d.
237
+ Then apply standard transformer action.
238
+ Finally, reshape to image
239
+ """
240
+
241
+ def __init__(
242
+ self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None
243
+ ):
244
+ super().__init__()
245
+ self.in_channels = in_channels
246
+ inner_dim = n_heads * d_head
247
+ self.norm = Normalize(in_channels)
248
+
249
+ self.proj_in = nn.Conv2d(
250
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
251
+ )
252
+
253
+ self.transformer_blocks = nn.ModuleList(
254
+ [
255
+ BasicTransformerBlock(
256
+ inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim
257
+ )
258
+ for d in range(depth)
259
+ ]
260
+ )
261
+
262
+ self.proj_out = zero_module(
263
+ nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
264
+ )
265
+
266
+ def forward(self, x, context=None):
267
+ # note: if no context is given, cross-attention defaults to self-attention
268
+ b, c, h, w = x.shape
269
+ x_in = x
270
+ x = self.norm(x)
271
+ x = self.proj_in(x)
272
+ x = rearrange(x, "b c h w -> b (h w) c")
273
+ for block in self.transformer_blocks:
274
+ x = block(x, context=context)
275
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
276
+ x = self.proj_out(x)
277
+ return x + x_in
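A small shape sketch for SpatialTransformer (it relies on swim.modules.diffusionmodules.util.checkpoint, imported at the top of this file but not shown in the commit; all sizes below are illustrative):

```python
import torch
from swim.modules.attention import SpatialTransformer

block = SpatialTransformer(in_channels=64, n_heads=4, d_head=16, depth=1, context_dim=128)
x = torch.randn(2, 64, 32, 32)   # feature map
ctx = torch.randn(2, 10, 128)    # ten conditioning tokens of width context_dim
out = block(x, context=ctx)      # cross-attention against ctx
print(out.shape)                 # torch.Size([2, 64, 32, 32]), same as the input
# With context=None every block falls back to self-attention, as noted in forward().
```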
swim/modules/dataset.py ADDED
@@ -0,0 +1,145 @@
1
+ from typing import Literal, List
2
+
3
+ import os
4
+ import json
5
+ import torch
6
+ import torchvision.transforms as T
7
+ from torch.utils.data import Dataset, DataLoader
8
+ from PIL import Image
9
+ from lightning import LightningDataModule
10
+
11
+
12
+ class SwimDataset(Dataset):
13
+ def __init__(
14
+ self,
15
+ root_dir: str = "./datasets/swim_data",
16
+ split: Literal["train", "val"] = "train",
17
+ img_size: int = 512,
18
+ ):
19
+ super().__init__()
20
+ self.root_dir = root_dir
21
+ self.split_dir = os.path.join(root_dir, split)
22
+ self.img_size = img_size
23
+
24
+ if split == "train":
25
+ self.transform = T.Compose(
26
+ [
27
+ T.Resize(img_size), # smaller edge of image resized to img_size
28
+ T.RandomCrop(img_size), # get a random crop of img_size x img_size
29
+ T.RandomHorizontalFlip(),
30
+ T.ToTensor(),
31
+ T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
32
+ ]
33
+ )
34
+ elif split == "val":
35
+ self.transform = T.Compose(
36
+ [
37
+ T.Resize(img_size),
38
+ T.CenterCrop(img_size),
39
+ T.ToTensor(),
40
+ T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
41
+ ]
42
+ )
43
+
44
+ with open(os.path.join(self.split_dir, "labels.json"), "r") as f:
45
+ self.data = json.load(f)
46
+
47
+ # filter out images that are both at night and have adverse weather conditions
48
+ self.data = [
49
+ img
50
+ for img in self.data
51
+ if not (img["timeofday"] == "night" and img["weather"] != "clear")
52
+ ]
53
+
54
+ def __len__(self):
55
+ return len(self.data)
56
+
57
+ def __getitem__(self, idx):
58
+ data = self.data[idx]
59
+
60
+ # load image
61
+ img_path = os.path.join(self.split_dir, "images", data["name"])
62
+ img = Image.open(img_path).convert("RGB")
63
+ img = self.transform(img)
64
+
65
+ # load style
66
+ if data["weather"] != "clear":
67
+ style_name = data["weather"]
68
+ elif data["timeofday"] == "night":
69
+ style_name = "night"
70
+ else:
71
+ style_name = "clear"
72
+
73
+ # true if image has any styles
74
+ style_flag = style_name != "clear"
75
+
76
+ # one-hot encode style
77
+ style = torch.zeros(4)
78
+
79
+ if style_flag:
80
+ style[self.get_stylenames().index(style_name)] = 1
81
+
82
+ return {
83
+ "image": img,
84
+ "style": style,
85
+ "style_flag": style_flag,
86
+ }
87
+
88
+ def get_stylenames(self) -> List[str]:
89
+ return ["rain", "snow", "fog", "night"]
90
+
91
+
92
+ class SwimDataModule(LightningDataModule):
93
+ def __init__(
94
+ self,
95
+ root_dir: str = "./datasets/swim_data",
96
+ batch_size: int = 1,
97
+ img_size: int = 512,
98
+ ):
99
+ super().__init__()
100
+ self.root_dir = root_dir
101
+ self.img_size = img_size
102
+ self.batch_size = batch_size
103
+
104
+ def setup(self, stage=None):
105
+ if stage == "fit" or stage is None:
106
+ self.train_dataset = SwimDataset(
107
+ root_dir=self.root_dir, split="train", img_size=self.img_size
108
+ )
109
+ self.val_dataset = SwimDataset(
110
+ root_dir=self.root_dir, split="val", img_size=self.img_size
111
+ )
112
+
113
+ def train_dataloader(self):
114
+ return DataLoader(
115
+ self.train_dataset,
116
+ batch_size=self.batch_size,
117
+ shuffle=True,
118
+ num_workers=4,
119
+ collate_fn=self.custom_collate_fn,
120
+ )
121
+
122
+ def val_dataloader(self):
123
+ return DataLoader(
124
+ self.val_dataset,
125
+ batch_size=self.batch_size,
126
+ shuffle=False,
127
+ num_workers=4,
128
+ collate_fn=self.custom_collate_fn,
129
+ )
130
+
131
+ def test_dataloader(self):
132
+ return DataLoader(
133
+ self.val_dataset,
134
+ batch_size=1,
135
+ shuffle=False,
136
+ num_workers=4,
137
+ collate_fn=self.custom_collate_fn,
138
+ )
139
+
140
+ @staticmethod
141
+ def custom_collate_fn(batch):
142
+ images = torch.stack([item["image"] for item in batch])
143
+ styles = torch.stack([item["style"] for item in batch])
144
+ style_flags = [item["style_flag"] for item in batch]
145
+ return {"images": images, "styles": styles, "style_flags": style_flags}
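A minimal sketch of how the datamodule above is consumed (illustrative, not part of the commit; the directory layout <root>/<split>/images plus <root>/<split>/labels.json follows the code, while the batch size here is an arbitrary choice):

from swim.modules.dataset import SwimDataModule

dm = SwimDataModule(root_dir="./datasets/swim_data", batch_size=4, img_size=512)
dm.setup("fit")
batch = next(iter(dm.train_dataloader()))
# batch["images"]:      (4, 3, 512, 512) float tensor, normalized to [-1, 1]
# batch["styles"]:      (4, 4) one-hot over ["rain", "snow", "fog", "night"],
#                       all zeros for "clear" images
# batch["style_flags"]: list of 4 bools, True when the style is not "clear"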
swim/{__init__.py → modules/diffusionmodules/__init__.py} RENAMED
File without changes
swim/modules/diffusionmodules/model.py ADDED
@@ -0,0 +1,1010 @@
1
+ # pytorch_diffusion + derived encoder decoder
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ from einops import rearrange
7
+
8
+ from swim.utils import instantiate_from_config
9
+ from swim.modules.distributions.distributions import DiagonalGaussianDistribution
10
+ from swim.modules.attention import LinearAttention
11
+
12
+
13
+ def get_timestep_embedding(timesteps, embedding_dim):
14
+ """
15
+ Build sinusoidal timestep embeddings, as used in Denoising Diffusion
16
+ Probabilistic Models (implementation originally from Fairseq).
17
+ This matches the implementation in tensor2tensor,
18
+ but differs slightly
19
+ from the description in Section 3.5 of "Attention Is All You Need".
20
+ """
21
+ assert len(timesteps.shape) == 1
22
+
23
+ half_dim = embedding_dim // 2
24
+ emb = math.log(10000) / (half_dim - 1)
25
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
26
+ emb = emb.to(device=timesteps.device)
27
+ emb = timesteps.float()[:, None] * emb[None, :]
28
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
29
+ if embedding_dim % 2 == 1: # zero pad
30
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
31
+ return emb
32
+
33
+
34
+ def nonlinearity(x):
35
+ # swish
36
+ return x * torch.sigmoid(x)
37
+
38
+
39
+ def Normalize(in_channels, num_groups=32):
40
+ return torch.nn.GroupNorm(
41
+ num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
42
+ )
43
+
44
+
45
+ class Upsample(nn.Module):
46
+ def __init__(self, in_channels, with_conv):
47
+ super().__init__()
48
+ self.with_conv = with_conv
49
+ if self.with_conv:
50
+ self.conv = torch.nn.Conv2d(
51
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
52
+ )
53
+
54
+ def forward(self, x):
55
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
56
+ if self.with_conv:
57
+ x = self.conv(x)
58
+ return x
59
+
60
+
61
+ class Downsample(nn.Module):
62
+ def __init__(self, in_channels, with_conv):
63
+ super().__init__()
64
+ self.with_conv = with_conv
65
+ if self.with_conv:
66
+ # no asymmetric padding in torch conv, must do it ourselves
67
+ self.conv = torch.nn.Conv2d(
68
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
69
+ )
70
+
71
+ def forward(self, x):
72
+ if self.with_conv:
73
+ pad = (0, 1, 0, 1)
74
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
75
+ x = self.conv(x)
76
+ else:
77
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
78
+ return x
79
+
80
+
81
+ class ResnetBlock(nn.Module):
82
+ def __init__(
83
+ self,
84
+ *,
85
+ in_channels,
86
+ out_channels=None,
87
+ conv_shortcut=False,
88
+ dropout,
89
+ temb_channels=512,
90
+ ):
91
+ super().__init__()
92
+ self.in_channels = in_channels
93
+ out_channels = in_channels if out_channels is None else out_channels
94
+ self.out_channels = out_channels
95
+ self.use_conv_shortcut = conv_shortcut
96
+
97
+ self.norm1 = Normalize(in_channels)
98
+ self.conv1 = torch.nn.Conv2d(
99
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
100
+ )
101
+ if temb_channels > 0:
102
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
103
+ self.norm2 = Normalize(out_channels)
104
+ self.dropout = torch.nn.Dropout(dropout)
105
+ self.conv2 = torch.nn.Conv2d(
106
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
107
+ )
108
+ if self.in_channels != self.out_channels:
109
+ if self.use_conv_shortcut:
110
+ self.conv_shortcut = torch.nn.Conv2d(
111
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
112
+ )
113
+ else:
114
+ self.nin_shortcut = torch.nn.Conv2d(
115
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
116
+ )
117
+
118
+ def forward(self, x, temb):
119
+ h = x
120
+ h = self.norm1(h)
121
+ h = nonlinearity(h)
122
+ h = self.conv1(h)
123
+
124
+ if temb is not None:
125
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
126
+
127
+ h = self.norm2(h)
128
+ h = nonlinearity(h)
129
+ h = self.dropout(h)
130
+ h = self.conv2(h)
131
+
132
+ if self.in_channels != self.out_channels:
133
+ if self.use_conv_shortcut:
134
+ x = self.conv_shortcut(x)
135
+ else:
136
+ x = self.nin_shortcut(x)
137
+
138
+ return x + h
139
+
140
+
141
+ class LinAttnBlock(LinearAttention):
142
+ """to match AttnBlock usage"""
143
+
144
+ def __init__(self, in_channels):
145
+ super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
146
+
147
+
148
+ class AttnBlock(nn.Module):
149
+ def __init__(self, in_channels):
150
+ super().__init__()
151
+ self.in_channels = in_channels
152
+
153
+ self.norm = Normalize(in_channels)
154
+ self.q = torch.nn.Conv2d(
155
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
156
+ )
157
+ self.k = torch.nn.Conv2d(
158
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
159
+ )
160
+ self.v = torch.nn.Conv2d(
161
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
162
+ )
163
+ self.proj_out = torch.nn.Conv2d(
164
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
165
+ )
166
+
167
+ def forward(self, x):
168
+ h_ = x
169
+ h_ = self.norm(h_)
170
+ q = self.q(h_)
171
+ k = self.k(h_)
172
+ v = self.v(h_)
173
+
174
+ # compute attention
175
+ b, c, h, w = q.shape
176
+ q = q.reshape(b, c, h * w)
177
+ q = q.permute(0, 2, 1) # b,hw,c
178
+ k = k.reshape(b, c, h * w) # b,c,hw
179
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
180
+ w_ = w_ * (int(c) ** (-0.5))
181
+ w_ = torch.nn.functional.softmax(w_, dim=2)
182
+
183
+ # attend to values
184
+ v = v.reshape(b, c, h * w)
185
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
186
+ h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
187
+ h_ = h_.reshape(b, c, h, w)
188
+
189
+ h_ = self.proj_out(h_)
190
+
191
+ return x + h_
192
+
193
+
194
+ def make_attn(in_channels, attn_type="vanilla"):
195
+ assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown"
196
+ print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
197
+ if attn_type == "vanilla":
198
+ return AttnBlock(in_channels)
199
+ elif attn_type == "none":
200
+ return nn.Identity(in_channels)
201
+ else:
202
+ return LinAttnBlock(in_channels)
203
+
204
+
205
+ class Model(nn.Module):
206
+ def __init__(
207
+ self,
208
+ *,
209
+ ch,
210
+ out_ch,
211
+ ch_mult=(1, 2, 4, 8),
212
+ num_res_blocks,
213
+ attn_resolutions,
214
+ dropout=0.0,
215
+ resamp_with_conv=True,
216
+ in_channels,
217
+ resolution,
218
+ use_timestep=True,
219
+ use_linear_attn=False,
220
+ attn_type="vanilla",
221
+ ):
222
+ super().__init__()
223
+ if use_linear_attn:
224
+ attn_type = "linear"
225
+ self.ch = ch
226
+ self.temb_ch = self.ch * 4
227
+ self.num_resolutions = len(ch_mult)
228
+ self.num_res_blocks = num_res_blocks
229
+ self.resolution = resolution
230
+ self.in_channels = in_channels
231
+
232
+ self.use_timestep = use_timestep
233
+ if self.use_timestep:
234
+ # timestep embedding
235
+ self.temb = nn.Module()
236
+ self.temb.dense = nn.ModuleList(
237
+ [
238
+ torch.nn.Linear(self.ch, self.temb_ch),
239
+ torch.nn.Linear(self.temb_ch, self.temb_ch),
240
+ ]
241
+ )
242
+
243
+ # downsampling
244
+ self.conv_in = torch.nn.Conv2d(
245
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
246
+ )
247
+
248
+ curr_res = resolution
249
+ in_ch_mult = (1,) + tuple(ch_mult)
250
+ self.down = nn.ModuleList()
251
+ for i_level in range(self.num_resolutions):
252
+ block = nn.ModuleList()
253
+ attn = nn.ModuleList()
254
+ block_in = ch * in_ch_mult[i_level]
255
+ block_out = ch * ch_mult[i_level]
256
+ for i_block in range(self.num_res_blocks):
257
+ block.append(
258
+ ResnetBlock(
259
+ in_channels=block_in,
260
+ out_channels=block_out,
261
+ temb_channels=self.temb_ch,
262
+ dropout=dropout,
263
+ )
264
+ )
265
+ block_in = block_out
266
+ if curr_res in attn_resolutions:
267
+ attn.append(make_attn(block_in, attn_type=attn_type))
268
+ down = nn.Module()
269
+ down.block = block
270
+ down.attn = attn
271
+ if i_level != self.num_resolutions - 1:
272
+ down.downsample = Downsample(block_in, resamp_with_conv)
273
+ curr_res = curr_res // 2
274
+ self.down.append(down)
275
+
276
+ # middle
277
+ self.mid = nn.Module()
278
+ self.mid.block_1 = ResnetBlock(
279
+ in_channels=block_in,
280
+ out_channels=block_in,
281
+ temb_channels=self.temb_ch,
282
+ dropout=dropout,
283
+ )
284
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
285
+ self.mid.block_2 = ResnetBlock(
286
+ in_channels=block_in,
287
+ out_channels=block_in,
288
+ temb_channels=self.temb_ch,
289
+ dropout=dropout,
290
+ )
291
+
292
+ # upsampling
293
+ self.up = nn.ModuleList()
294
+ for i_level in reversed(range(self.num_resolutions)):
295
+ block = nn.ModuleList()
296
+ attn = nn.ModuleList()
297
+ block_out = ch * ch_mult[i_level]
298
+ skip_in = ch * ch_mult[i_level]
299
+ for i_block in range(self.num_res_blocks + 1):
300
+ if i_block == self.num_res_blocks:
301
+ skip_in = ch * in_ch_mult[i_level]
302
+ block.append(
303
+ ResnetBlock(
304
+ in_channels=block_in + skip_in,
305
+ out_channels=block_out,
306
+ temb_channels=self.temb_ch,
307
+ dropout=dropout,
308
+ )
309
+ )
310
+ block_in = block_out
311
+ if curr_res in attn_resolutions:
312
+ attn.append(make_attn(block_in, attn_type=attn_type))
313
+ up = nn.Module()
314
+ up.block = block
315
+ up.attn = attn
316
+ if i_level != 0:
317
+ up.upsample = Upsample(block_in, resamp_with_conv)
318
+ curr_res = curr_res * 2
319
+ self.up.insert(0, up) # prepend to get consistent order
320
+
321
+ # end
322
+ self.norm_out = Normalize(block_in)
323
+ self.conv_out = torch.nn.Conv2d(
324
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
325
+ )
326
+
327
+ def forward(self, x, t=None, context=None):
328
+ # assert x.shape[2] == x.shape[3] == self.resolution
329
+ if context is not None:
330
+ # assume aligned context, cat along channel axis
331
+ x = torch.cat((x, context), dim=1)
332
+ if self.use_timestep:
333
+ # timestep embedding
334
+ assert t is not None
335
+ temb = get_timestep_embedding(t, self.ch)
336
+ temb = self.temb.dense[0](temb)
337
+ temb = nonlinearity(temb)
338
+ temb = self.temb.dense[1](temb)
339
+ else:
340
+ temb = None
341
+
342
+ # downsampling
343
+ hs = [self.conv_in(x)]
344
+ for i_level in range(self.num_resolutions):
345
+ for i_block in range(self.num_res_blocks):
346
+ h = self.down[i_level].block[i_block](hs[-1], temb)
347
+ if len(self.down[i_level].attn) > 0:
348
+ h = self.down[i_level].attn[i_block](h)
349
+ hs.append(h)
350
+ if i_level != self.num_resolutions - 1:
351
+ hs.append(self.down[i_level].downsample(hs[-1]))
352
+
353
+ # middle
354
+ h = hs[-1]
355
+ h = self.mid.block_1(h, temb)
356
+ h = self.mid.attn_1(h)
357
+ h = self.mid.block_2(h, temb)
358
+
359
+ # upsampling
360
+ for i_level in reversed(range(self.num_resolutions)):
361
+ for i_block in range(self.num_res_blocks + 1):
362
+ h = self.up[i_level].block[i_block](
363
+ torch.cat([h, hs.pop()], dim=1), temb
364
+ )
365
+ if len(self.up[i_level].attn) > 0:
366
+ h = self.up[i_level].attn[i_block](h)
367
+ if i_level != 0:
368
+ h = self.up[i_level].upsample(h)
369
+
370
+ # end
371
+ h = self.norm_out(h)
372
+ h = nonlinearity(h)
373
+ h = self.conv_out(h)
374
+ return h
375
+
376
+ def get_last_layer(self):
377
+ return self.conv_out.weight
378
+
379
+
380
+ class Encoder(nn.Module):
381
+ def __init__(
382
+ self,
383
+ *,
384
+ ch,
385
+ out_ch,
386
+ ch_mult=(1, 2, 4, 8),
387
+ num_res_blocks,
388
+ attn_resolutions,
389
+ dropout=0.0,
390
+ resamp_with_conv=True,
391
+ in_channels,
392
+ resolution,
393
+ z_channels,
394
+ double_z=True,
395
+ use_linear_attn=False,
396
+ attn_type="vanilla",
397
+ **ignore_kwargs,
398
+ ):
399
+ super().__init__()
400
+ if use_linear_attn:
401
+ attn_type = "linear"
402
+ self.ch = ch
403
+ self.temb_ch = 0
404
+ self.num_resolutions = len(ch_mult)
405
+ self.num_res_blocks = num_res_blocks
406
+ self.resolution = resolution
407
+ self.in_channels = in_channels
408
+
409
+ # downsampling
410
+ self.conv_in = torch.nn.Conv2d(
411
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
412
+ )
413
+
414
+ curr_res = resolution
415
+ in_ch_mult = (1,) + tuple(ch_mult)
416
+ self.in_ch_mult = in_ch_mult
417
+ self.down = nn.ModuleList()
418
+ for i_level in range(self.num_resolutions):
419
+ block = nn.ModuleList()
420
+ attn = nn.ModuleList()
421
+ block_in = ch * in_ch_mult[i_level]
422
+ block_out = ch * ch_mult[i_level]
423
+ for i_block in range(self.num_res_blocks):
424
+ block.append(
425
+ ResnetBlock(
426
+ in_channels=block_in,
427
+ out_channels=block_out,
428
+ temb_channels=self.temb_ch,
429
+ dropout=dropout,
430
+ )
431
+ )
432
+ block_in = block_out
433
+ if curr_res in attn_resolutions:
434
+ attn.append(make_attn(block_in, attn_type=attn_type))
435
+ down = nn.Module()
436
+ down.block = block
437
+ down.attn = attn
438
+ if i_level != self.num_resolutions - 1:
439
+ down.downsample = Downsample(block_in, resamp_with_conv)
440
+ curr_res = curr_res // 2
441
+ self.down.append(down)
442
+
443
+ # middle
444
+ self.mid = nn.Module()
445
+ self.mid.block_1 = ResnetBlock(
446
+ in_channels=block_in,
447
+ out_channels=block_in,
448
+ temb_channels=self.temb_ch,
449
+ dropout=dropout,
450
+ )
451
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
452
+ self.mid.block_2 = ResnetBlock(
453
+ in_channels=block_in,
454
+ out_channels=block_in,
455
+ temb_channels=self.temb_ch,
456
+ dropout=dropout,
457
+ )
458
+
459
+ # end
460
+ self.norm_out = Normalize(block_in)
461
+ self.conv_out = torch.nn.Conv2d(
462
+ block_in,
463
+ 2 * z_channels if double_z else z_channels,
464
+ kernel_size=3,
465
+ stride=1,
466
+ padding=1,
467
+ )
468
+
469
+ def forward(self, x):
470
+ # timestep embedding
471
+ temb = None
472
+
473
+ # downsampling
474
+ hs = [self.conv_in(x)]
475
+ for i_level in range(self.num_resolutions):
476
+ for i_block in range(self.num_res_blocks):
477
+ h = self.down[i_level].block[i_block](hs[-1], temb)
478
+ if len(self.down[i_level].attn) > 0:
479
+ h = self.down[i_level].attn[i_block](h)
480
+ hs.append(h)
481
+ if i_level != self.num_resolutions - 1:
482
+ hs.append(self.down[i_level].downsample(hs[-1]))
483
+
484
+ # middle
485
+ h = hs[-1]
486
+ h = self.mid.block_1(h, temb)
487
+ h = self.mid.attn_1(h)
488
+ h = self.mid.block_2(h, temb)
489
+
490
+ # end
491
+ h = self.norm_out(h)
492
+ h = nonlinearity(h)
493
+ h = self.conv_out(h)
494
+ return h
495
+
496
+
497
+ class Decoder(nn.Module):
498
+ def __init__(
499
+ self,
500
+ *,
501
+ ch,
502
+ out_ch,
503
+ ch_mult=(1, 2, 4, 8),
504
+ num_res_blocks,
505
+ attn_resolutions,
506
+ dropout=0.0,
507
+ resamp_with_conv=True,
508
+ in_channels,
509
+ resolution,
510
+ z_channels,
511
+ give_pre_end=False,
512
+ tanh_out=False,
513
+ use_linear_attn=False,
514
+ attn_type="vanilla",
515
+ **ignorekwargs,
516
+ ):
517
+ super().__init__()
518
+ if use_linear_attn:
519
+ attn_type = "linear"
520
+ self.ch = ch
521
+ self.temb_ch = 0
522
+ self.num_resolutions = len(ch_mult)
523
+ self.num_res_blocks = num_res_blocks
524
+ self.resolution = resolution
525
+ self.in_channels = in_channels
526
+ self.give_pre_end = give_pre_end
527
+ self.tanh_out = tanh_out
528
+
529
+ # compute in_ch_mult, block_in and curr_res at lowest res
530
+ in_ch_mult = (1,) + tuple(ch_mult)
531
+ block_in = ch * ch_mult[self.num_resolutions - 1]
532
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
533
+ self.z_shape = (1, z_channels, curr_res, curr_res)
534
+ print(
535
+ "Working with z of shape {} = {} dimensions.".format(
536
+ self.z_shape, np.prod(self.z_shape)
537
+ )
538
+ )
539
+
540
+ # z to block_in
541
+ self.conv_in = torch.nn.Conv2d(
542
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
543
+ )
544
+
545
+ # middle
546
+ self.mid = nn.Module()
547
+ self.mid.block_1 = ResnetBlock(
548
+ in_channels=block_in,
549
+ out_channels=block_in,
550
+ temb_channels=self.temb_ch,
551
+ dropout=dropout,
552
+ )
553
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
554
+ self.mid.block_2 = ResnetBlock(
555
+ in_channels=block_in,
556
+ out_channels=block_in,
557
+ temb_channels=self.temb_ch,
558
+ dropout=dropout,
559
+ )
560
+
561
+ # upsampling
562
+ self.up = nn.ModuleList()
563
+ for i_level in reversed(range(self.num_resolutions)):
564
+ block = nn.ModuleList()
565
+ attn = nn.ModuleList()
566
+ block_out = ch * ch_mult[i_level]
567
+ for i_block in range(self.num_res_blocks + 1):
568
+ block.append(
569
+ ResnetBlock(
570
+ in_channels=block_in,
571
+ out_channels=block_out,
572
+ temb_channels=self.temb_ch,
573
+ dropout=dropout,
574
+ )
575
+ )
576
+ block_in = block_out
577
+ if curr_res in attn_resolutions:
578
+ attn.append(make_attn(block_in, attn_type=attn_type))
579
+ up = nn.Module()
580
+ up.block = block
581
+ up.attn = attn
582
+ if i_level != 0:
583
+ up.upsample = Upsample(block_in, resamp_with_conv)
584
+ curr_res = curr_res * 2
585
+ self.up.insert(0, up) # prepend to get consistent order
586
+
587
+ # end
588
+ self.norm_out = Normalize(block_in)
589
+ self.conv_out = torch.nn.Conv2d(
590
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
591
+ )
592
+
593
+ def forward(self, z):
594
+ # assert z.shape[1:] == self.z_shape[1:]
595
+ self.last_z_shape = z.shape
596
+
597
+ # timestep embedding
598
+ temb = None
599
+
600
+ # z to block_in
601
+ h = self.conv_in(z)
602
+
603
+ # middle
604
+ h = self.mid.block_1(h, temb)
605
+ h = self.mid.attn_1(h)
606
+ h = self.mid.block_2(h, temb)
607
+
608
+ # upsampling
609
+ for i_level in reversed(range(self.num_resolutions)):
610
+ for i_block in range(self.num_res_blocks + 1):
611
+ h = self.up[i_level].block[i_block](h, temb)
612
+ if len(self.up[i_level].attn) > 0:
613
+ h = self.up[i_level].attn[i_block](h)
614
+ if i_level != 0:
615
+ h = self.up[i_level].upsample(h)
616
+
617
+ # end
618
+ if self.give_pre_end:
619
+ return h
620
+
621
+ h = self.norm_out(h)
622
+ h = nonlinearity(h)
623
+ h = self.conv_out(h)
624
+ if self.tanh_out:
625
+ h = torch.tanh(h)
626
+ return h
627
+
628
+
629
+ class SimpleDecoder(nn.Module):
630
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
631
+ super().__init__()
632
+ self.model = nn.ModuleList(
633
+ [
634
+ nn.Conv2d(in_channels, in_channels, 1),
635
+ ResnetBlock(
636
+ in_channels=in_channels,
637
+ out_channels=2 * in_channels,
638
+ temb_channels=0,
639
+ dropout=0.0,
640
+ ),
641
+ ResnetBlock(
642
+ in_channels=2 * in_channels,
643
+ out_channels=4 * in_channels,
644
+ temb_channels=0,
645
+ dropout=0.0,
646
+ ),
647
+ ResnetBlock(
648
+ in_channels=4 * in_channels,
649
+ out_channels=2 * in_channels,
650
+ temb_channels=0,
651
+ dropout=0.0,
652
+ ),
653
+ nn.Conv2d(2 * in_channels, in_channels, 1),
654
+ Upsample(in_channels, with_conv=True),
655
+ ]
656
+ )
657
+ # end
658
+ self.norm_out = Normalize(in_channels)
659
+ self.conv_out = torch.nn.Conv2d(
660
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
661
+ )
662
+
663
+ def forward(self, x):
664
+ for i, layer in enumerate(self.model):
665
+ if i in [1, 2, 3]:
666
+ x = layer(x, None)
667
+ else:
668
+ x = layer(x)
669
+
670
+ h = self.norm_out(x)
671
+ h = nonlinearity(h)
672
+ x = self.conv_out(h)
673
+ return x
674
+
675
+
676
+ class UpsampleDecoder(nn.Module):
677
+ def __init__(
678
+ self,
679
+ in_channels,
680
+ out_channels,
681
+ ch,
682
+ num_res_blocks,
683
+ resolution,
684
+ ch_mult=(2, 2),
685
+ dropout=0.0,
686
+ ):
687
+ super().__init__()
688
+ # upsampling
689
+ self.temb_ch = 0
690
+ self.num_resolutions = len(ch_mult)
691
+ self.num_res_blocks = num_res_blocks
692
+ block_in = in_channels
693
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
694
+ self.res_blocks = nn.ModuleList()
695
+ self.upsample_blocks = nn.ModuleList()
696
+ for i_level in range(self.num_resolutions):
697
+ res_block = []
698
+ block_out = ch * ch_mult[i_level]
699
+ for i_block in range(self.num_res_blocks + 1):
700
+ res_block.append(
701
+ ResnetBlock(
702
+ in_channels=block_in,
703
+ out_channels=block_out,
704
+ temb_channels=self.temb_ch,
705
+ dropout=dropout,
706
+ )
707
+ )
708
+ block_in = block_out
709
+ self.res_blocks.append(nn.ModuleList(res_block))
710
+ if i_level != self.num_resolutions - 1:
711
+ self.upsample_blocks.append(Upsample(block_in, True))
712
+ curr_res = curr_res * 2
713
+
714
+ # end
715
+ self.norm_out = Normalize(block_in)
716
+ self.conv_out = torch.nn.Conv2d(
717
+ block_in, out_channels, kernel_size=3, stride=1, padding=1
718
+ )
719
+
720
+ def forward(self, x):
721
+ # upsampling
722
+ h = x
723
+ for k, i_level in enumerate(range(self.num_resolutions)):
724
+ for i_block in range(self.num_res_blocks + 1):
725
+ h = self.res_blocks[i_level][i_block](h, None)
726
+ if i_level != self.num_resolutions - 1:
727
+ h = self.upsample_blocks[k](h)
728
+ h = self.norm_out(h)
729
+ h = nonlinearity(h)
730
+ h = self.conv_out(h)
731
+ return h
732
+
733
+
734
+ class LatentRescaler(nn.Module):
735
+ def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
736
+ super().__init__()
737
+ # residual block, interpolate, residual block
738
+ self.factor = factor
739
+ self.conv_in = nn.Conv2d(
740
+ in_channels, mid_channels, kernel_size=3, stride=1, padding=1
741
+ )
742
+ self.res_block1 = nn.ModuleList(
743
+ [
744
+ ResnetBlock(
745
+ in_channels=mid_channels,
746
+ out_channels=mid_channels,
747
+ temb_channels=0,
748
+ dropout=0.0,
749
+ )
750
+ for _ in range(depth)
751
+ ]
752
+ )
753
+ self.attn = AttnBlock(mid_channels)
754
+ self.res_block2 = nn.ModuleList(
755
+ [
756
+ ResnetBlock(
757
+ in_channels=mid_channels,
758
+ out_channels=mid_channels,
759
+ temb_channels=0,
760
+ dropout=0.0,
761
+ )
762
+ for _ in range(depth)
763
+ ]
764
+ )
765
+
766
+ self.conv_out = nn.Conv2d(
767
+ mid_channels,
768
+ out_channels,
769
+ kernel_size=1,
770
+ )
771
+
772
+ def forward(self, x):
773
+ x = self.conv_in(x)
774
+ for block in self.res_block1:
775
+ x = block(x, None)
776
+ x = torch.nn.functional.interpolate(
777
+ x,
778
+ size=(
779
+ int(round(x.shape[2] * self.factor)),
780
+ int(round(x.shape[3] * self.factor)),
781
+ ),
782
+ )
783
+ x = self.attn(x)
784
+ for block in self.res_block2:
785
+ x = block(x, None)
786
+ x = self.conv_out(x)
787
+ return x
788
+
789
+
790
+ class MergedRescaleEncoder(nn.Module):
791
+ def __init__(
792
+ self,
793
+ in_channels,
794
+ ch,
795
+ resolution,
796
+ out_ch,
797
+ num_res_blocks,
798
+ attn_resolutions,
799
+ dropout=0.0,
800
+ resamp_with_conv=True,
801
+ ch_mult=(1, 2, 4, 8),
802
+ rescale_factor=1.0,
803
+ rescale_module_depth=1,
804
+ ):
805
+ super().__init__()
806
+ intermediate_chn = ch * ch_mult[-1]
807
+ self.encoder = Encoder(
808
+ in_channels=in_channels,
809
+ num_res_blocks=num_res_blocks,
810
+ ch=ch,
811
+ ch_mult=ch_mult,
812
+ z_channels=intermediate_chn,
813
+ double_z=False,
814
+ resolution=resolution,
815
+ attn_resolutions=attn_resolutions,
816
+ dropout=dropout,
817
+ resamp_with_conv=resamp_with_conv,
818
+ out_ch=None,
819
+ )
820
+ self.rescaler = LatentRescaler(
821
+ factor=rescale_factor,
822
+ in_channels=intermediate_chn,
823
+ mid_channels=intermediate_chn,
824
+ out_channels=out_ch,
825
+ depth=rescale_module_depth,
826
+ )
827
+
828
+ def forward(self, x):
829
+ x = self.encoder(x)
830
+ x = self.rescaler(x)
831
+ return x
832
+
833
+
834
+ class MergedRescaleDecoder(nn.Module):
835
+ def __init__(
836
+ self,
837
+ z_channels,
838
+ out_ch,
839
+ resolution,
840
+ num_res_blocks,
841
+ attn_resolutions,
842
+ ch,
843
+ ch_mult=(1, 2, 4, 8),
844
+ dropout=0.0,
845
+ resamp_with_conv=True,
846
+ rescale_factor=1.0,
847
+ rescale_module_depth=1,
848
+ ):
849
+ super().__init__()
850
+ tmp_chn = z_channels * ch_mult[-1]
851
+ self.decoder = Decoder(
852
+ out_ch=out_ch,
853
+ z_channels=tmp_chn,
854
+ attn_resolutions=attn_resolutions,
855
+ dropout=dropout,
856
+ resamp_with_conv=resamp_with_conv,
857
+ in_channels=None,
858
+ num_res_blocks=num_res_blocks,
859
+ ch_mult=ch_mult,
860
+ resolution=resolution,
861
+ ch=ch,
862
+ )
863
+ self.rescaler = LatentRescaler(
864
+ factor=rescale_factor,
865
+ in_channels=z_channels,
866
+ mid_channels=tmp_chn,
867
+ out_channels=tmp_chn,
868
+ depth=rescale_module_depth,
869
+ )
870
+
871
+ def forward(self, x):
872
+ x = self.rescaler(x)
873
+ x = self.decoder(x)
874
+ return x
875
+
876
+
877
+ class Upsampler(nn.Module):
878
+ def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
879
+ super().__init__()
880
+ assert out_size >= in_size
881
+ num_blocks = int(np.log2(out_size // in_size)) + 1
882
+ factor_up = 1.0 + (out_size % in_size)
883
+ print(
884
+ f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}"
885
+ )
886
+ self.rescaler = LatentRescaler(
887
+ factor=factor_up,
888
+ in_channels=in_channels,
889
+ mid_channels=2 * in_channels,
890
+ out_channels=in_channels,
891
+ )
892
+ self.decoder = Decoder(
893
+ out_ch=out_channels,
894
+ resolution=out_size,
895
+ z_channels=in_channels,
896
+ num_res_blocks=2,
897
+ attn_resolutions=[],
898
+ in_channels=None,
899
+ ch=in_channels,
900
+ ch_mult=[ch_mult for _ in range(num_blocks)],
901
+ )
902
+
903
+ def forward(self, x):
904
+ x = self.rescaler(x)
905
+ x = self.decoder(x)
906
+ return x
907
+
908
+
909
+ class Resize(nn.Module):
910
+ def __init__(self, in_channels=None, learned=False, mode="bilinear"):
911
+ super().__init__()
912
+ self.with_conv = learned
913
+ self.mode = mode
914
+ if self.with_conv:
915
+ print(
916
+ f"Note: {self.__class__.__name__} uses learned downsampling and will ignore the fixed {mode} mode"
917
+ )
918
+ raise NotImplementedError()
919
+ assert in_channels is not None
920
+ # no asymmetric padding in torch conv, must do it ourselves
921
+ self.conv = torch.nn.Conv2d(
922
+ in_channels, in_channels, kernel_size=4, stride=2, padding=1
923
+ )
924
+
925
+ def forward(self, x, scale_factor=1.0):
926
+ if scale_factor == 1.0:
927
+ return x
928
+ else:
929
+ x = torch.nn.functional.interpolate(
930
+ x, mode=self.mode, align_corners=False, scale_factor=scale_factor
931
+ )
932
+ return x
933
+
934
+
935
+ class FirstStagePostProcessor(nn.Module):
936
+
937
+ def __init__(
938
+ self,
939
+ ch_mult: list,
940
+ in_channels,
941
+ pretrained_model: nn.Module = None,
942
+ reshape=False,
943
+ n_channels=None,
944
+ dropout=0.0,
945
+ pretrained_config=None,
946
+ ):
947
+ super().__init__()
948
+ if pretrained_config is None:
949
+ assert (
950
+ pretrained_model is not None
951
+ ), 'Either "pretrained_model" or "pretrained_config" must not be None'
952
+ self.pretrained_model = pretrained_model
953
+ else:
954
+ assert (
955
+ pretrained_config is not None
956
+ ), 'Either "pretrained_model" or "pretrained_config" must not be None'
957
+ self.instantiate_pretrained(pretrained_config)
958
+
959
+ self.do_reshape = reshape
960
+
961
+ if n_channels is None:
962
+ n_channels = self.pretrained_model.encoder.ch
963
+
964
+ self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2)
965
+ self.proj = nn.Conv2d(
966
+ in_channels, n_channels, kernel_size=3, stride=1, padding=1
967
+ )
968
+
969
+ blocks = []
970
+ downs = []
971
+ ch_in = n_channels
972
+ for m in ch_mult:
973
+ blocks.append(
974
+ ResnetBlock(
975
+ in_channels=ch_in, out_channels=m * n_channels, dropout=dropout
976
+ )
977
+ )
978
+ ch_in = m * n_channels
979
+ downs.append(Downsample(ch_in, with_conv=False))
980
+
981
+ self.model = nn.ModuleList(blocks)
982
+ self.downsampler = nn.ModuleList(downs)
983
+
984
+ def instantiate_pretrained(self, config):
985
+ model = instantiate_from_config(config)
986
+ self.pretrained_model = model.eval()
987
+ # self.pretrained_model.train = False
988
+ for param in self.pretrained_model.parameters():
989
+ param.requires_grad = False
990
+
991
+ @torch.no_grad()
992
+ def encode_with_pretrained(self, x):
993
+ c = self.pretrained_model.encode(x)
994
+ if isinstance(c, DiagonalGaussianDistribution):
995
+ c = c.mode()
996
+ return c
997
+
998
+ def forward(self, x):
999
+ z_fs = self.encode_with_pretrained(x)
1000
+ z = self.proj_norm(z_fs)
1001
+ z = self.proj(z)
1002
+ z = nonlinearity(z)
1003
+
1004
+ for submodel, downmodel in zip(self.model, self.downsampler):
1005
+ z = submodel(z, temb=None)
1006
+ z = downmodel(z)
1007
+
1008
+ if self.do_reshape:
1009
+ z = rearrange(z, "b c h w -> b (h w) c")
1010
+ return z
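To tie this file back to the ddconfig in configs/autoencoder/autoencoder_kl_32x32x4.yaml, a shape sketch follows (config values are copied from this commit; the 512x512 input and the slicing of the mean half are illustrative assumptions, since the split into mean and log-variance is presumably handled by AutoencoderKL through DiagonalGaussianDistribution, which this file imports):

import torch
from swim.modules.diffusionmodules.model import Encoder, Decoder

ddconfig = dict(double_z=True, z_channels=4, resolution=512, in_channels=3,
                out_ch=3, ch=128, ch_mult=[1, 2, 4, 4], num_res_blocks=2,
                attn_resolutions=[], dropout=0.0)
enc = Encoder(**ddconfig)
dec = Decoder(**ddconfig)

x = torch.randn(1, 3, 512, 512)
moments = enc(x)    # (1, 8, 64, 64): three downsamples (len(ch_mult) - 1),
                    # and 2 * z_channels output channels because double_z=True
z = moments[:, :4]  # e.g. keep the mean half of the Gaussian parameters
rec = dec(z)        # (1, 3, 512, 512)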
swim/modules/diffusionmodules/openaimodel.py ADDED
@@ -0,0 +1,1012 @@
1
+ from abc import abstractmethod
2
+ from functools import partial
3
+ import math
4
+ from typing import Iterable
5
+
6
+ import numpy as np
7
+ import torch as th
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ from swim.modules.diffusionmodules.util import (
12
+ checkpoint,
13
+ conv_nd,
14
+ linear,
15
+ avg_pool_nd,
16
+ zero_module,
17
+ normalization,
18
+ timestep_embedding,
19
+ )
20
+ from swim.modules.attention import SpatialTransformer
21
+
22
+
23
+ # dummy replace
24
+ def convert_module_to_f16(x):
25
+ pass
26
+
27
+
28
+ def convert_module_to_f32(x):
29
+ pass
30
+
31
+
32
+ ## go
33
+ class AttentionPool2d(nn.Module):
34
+ """
35
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
36
+ """
37
+
38
+ def __init__(
39
+ self,
40
+ spacial_dim: int,
41
+ embed_dim: int,
42
+ num_heads_channels: int,
43
+ output_dim: int = None,
44
+ ):
45
+ super().__init__()
46
+ self.positional_embedding = nn.Parameter(
47
+ th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5
48
+ )
49
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
50
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
51
+ self.num_heads = embed_dim // num_heads_channels
52
+ self.attention = QKVAttention(self.num_heads)
53
+
54
+ def forward(self, x):
55
+ b, c, *_spatial = x.shape
56
+ x = x.reshape(b, c, -1) # NC(HW)
57
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
58
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
59
+ x = self.qkv_proj(x)
60
+ x = self.attention(x)
61
+ x = self.c_proj(x)
62
+ return x[:, :, 0]
63
+
64
+
65
+ class TimestepBlock(nn.Module):
66
+ """
67
+ Any module where forward() takes timestep embeddings as a second argument.
68
+ """
69
+
70
+ @abstractmethod
71
+ def forward(self, x, emb):
72
+ """
73
+ Apply the module to `x` given `emb` timestep embeddings.
74
+ """
75
+
76
+
77
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
78
+ """
79
+ A sequential module that passes timestep embeddings to the children that
80
+ support it as an extra input.
81
+ """
82
+
83
+ def forward(self, x, emb, context=None):
84
+ for layer in self:
85
+ if isinstance(layer, TimestepBlock):
86
+ x = layer(x, emb)
87
+ elif isinstance(layer, SpatialTransformer):
88
+ x = layer(x, context)
89
+ else:
90
+ x = layer(x)
91
+ return x
92
+
93
+
94
+ class Upsample(nn.Module):
95
+ """
96
+ An upsampling layer with an optional convolution.
97
+ :param channels: channels in the inputs and outputs.
98
+ :param use_conv: a bool determining if a convolution is applied.
99
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
100
+ upsampling occurs in the inner-two dimensions.
101
+ """
102
+
103
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
104
+ super().__init__()
105
+ self.channels = channels
106
+ self.out_channels = out_channels or channels
107
+ self.use_conv = use_conv
108
+ self.dims = dims
109
+ if use_conv:
110
+ self.conv = conv_nd(
111
+ dims, self.channels, self.out_channels, 3, padding=padding
112
+ )
113
+
114
+ def forward(self, x):
115
+ assert x.shape[1] == self.channels
116
+ if self.dims == 3:
117
+ x = F.interpolate(
118
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
119
+ )
120
+ else:
121
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
122
+ if self.use_conv:
123
+ x = self.conv(x)
124
+ return x
125
+
126
+
127
+ class TransposedUpsample(nn.Module):
128
+ "Learned 2x upsampling without padding"
129
+
130
+ def __init__(self, channels, out_channels=None, ks=5):
131
+ super().__init__()
132
+ self.channels = channels
133
+ self.out_channels = out_channels or channels
134
+
135
+ self.up = nn.ConvTranspose2d(
136
+ self.channels, self.out_channels, kernel_size=ks, stride=2
137
+ )
138
+
139
+ def forward(self, x):
140
+ return self.up(x)
141
+
142
+
143
+ class Downsample(nn.Module):
144
+ """
145
+ A downsampling layer with an optional convolution.
146
+ :param channels: channels in the inputs and outputs.
147
+ :param use_conv: a bool determining if a convolution is applied.
148
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
149
+ downsampling occurs in the inner-two dimensions.
150
+ """
151
+
152
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
153
+ super().__init__()
154
+ self.channels = channels
155
+ self.out_channels = out_channels or channels
156
+ self.use_conv = use_conv
157
+ self.dims = dims
158
+ stride = 2 if dims != 3 else (1, 2, 2)
159
+ if use_conv:
160
+ self.op = conv_nd(
161
+ dims,
162
+ self.channels,
163
+ self.out_channels,
164
+ 3,
165
+ stride=stride,
166
+ padding=padding,
167
+ )
168
+ else:
169
+ assert self.channels == self.out_channels
170
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
171
+
172
+ def forward(self, x):
173
+ assert x.shape[1] == self.channels
174
+ return self.op(x)
175
+
176
+
177
+ class ResBlock(TimestepBlock):
178
+ """
179
+ A residual block that can optionally change the number of channels.
180
+ :param channels: the number of input channels.
181
+ :param emb_channels: the number of timestep embedding channels.
182
+ :param dropout: the rate of dropout.
183
+ :param out_channels: if specified, the number of out channels.
184
+ :param use_conv: if True and out_channels is specified, use a spatial
185
+ convolution instead of a smaller 1x1 convolution to change the
186
+ channels in the skip connection.
187
+ :param dims: determines if the signal is 1D, 2D, or 3D.
188
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
189
+ :param up: if True, use this block for upsampling.
190
+ :param down: if True, use this block for downsampling.
191
+ """
192
+
193
+ def __init__(
194
+ self,
195
+ channels,
196
+ emb_channels,
197
+ dropout,
198
+ out_channels=None,
199
+ use_conv=False,
200
+ use_scale_shift_norm=False,
201
+ dims=2,
202
+ use_checkpoint=False,
203
+ up=False,
204
+ down=False,
205
+ ):
206
+ super().__init__()
207
+ self.channels = channels
208
+ self.emb_channels = emb_channels
209
+ self.dropout = dropout
210
+ self.out_channels = out_channels or channels
211
+ self.use_conv = use_conv
212
+ self.use_checkpoint = use_checkpoint
213
+ self.use_scale_shift_norm = use_scale_shift_norm
214
+
215
+ self.in_layers = nn.Sequential(
216
+ normalization(channels),
217
+ nn.SiLU(),
218
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
219
+ )
220
+
221
+ self.updown = up or down
222
+
223
+ if up:
224
+ self.h_upd = Upsample(channels, False, dims)
225
+ self.x_upd = Upsample(channels, False, dims)
226
+ elif down:
227
+ self.h_upd = Downsample(channels, False, dims)
228
+ self.x_upd = Downsample(channels, False, dims)
229
+ else:
230
+ self.h_upd = self.x_upd = nn.Identity()
231
+
232
+ self.emb_layers = nn.Sequential(
233
+ nn.SiLU(),
234
+ linear(
235
+ emb_channels,
236
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
237
+ ),
238
+ )
239
+ self.out_layers = nn.Sequential(
240
+ normalization(self.out_channels),
241
+ nn.SiLU(),
242
+ nn.Dropout(p=dropout),
243
+ zero_module(
244
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
245
+ ),
246
+ )
247
+
248
+ if self.out_channels == channels:
249
+ self.skip_connection = nn.Identity()
250
+ elif use_conv:
251
+ self.skip_connection = conv_nd(
252
+ dims, channels, self.out_channels, 3, padding=1
253
+ )
254
+ else:
255
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
256
+
257
+ def forward(self, x, emb):
258
+ """
259
+ Apply the block to a Tensor, conditioned on a timestep embedding.
260
+ :param x: an [N x C x ...] Tensor of features.
261
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
262
+ :return: an [N x C x ...] Tensor of outputs.
263
+ """
264
+ return checkpoint(
265
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
266
+ )
267
+
268
+ def _forward(self, x, emb):
269
+ if self.updown:
270
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
271
+ h = in_rest(x)
272
+ h = self.h_upd(h)
273
+ x = self.x_upd(x)
274
+ h = in_conv(h)
275
+ else:
276
+ h = self.in_layers(x)
277
+ emb_out = self.emb_layers(emb).type(h.dtype)
278
+ while len(emb_out.shape) < len(h.shape):
279
+ emb_out = emb_out[..., None]
280
+ if self.use_scale_shift_norm:
281
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
282
+ scale, shift = th.chunk(emb_out, 2, dim=1)
283
+ h = out_norm(h) * (1 + scale) + shift
284
+ h = out_rest(h)
285
+ else:
286
+ h = h + emb_out
287
+ h = self.out_layers(h)
288
+ return self.skip_connection(x) + h
289
+
290
+
291
+ class AttentionBlock(nn.Module):
292
+ """
293
+ An attention block that allows spatial positions to attend to each other.
294
+ Originally ported from here, but adapted to the N-d case.
295
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
296
+ """
297
+
298
+ def __init__(
299
+ self,
300
+ channels,
301
+ num_heads=1,
302
+ num_head_channels=-1,
303
+ use_checkpoint=False,
304
+ use_new_attention_order=False,
305
+ ):
306
+ super().__init__()
307
+ self.channels = channels
308
+ if num_head_channels == -1:
309
+ self.num_heads = num_heads
310
+ else:
311
+ assert (
312
+ channels % num_head_channels == 0
313
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
314
+ self.num_heads = channels // num_head_channels
315
+ self.use_checkpoint = use_checkpoint
316
+ self.norm = normalization(channels)
317
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
318
+ if use_new_attention_order:
319
+ # split qkv before split heads
320
+ self.attention = QKVAttention(self.num_heads)
321
+ else:
322
+ # split heads before split qkv
323
+ self.attention = QKVAttentionLegacy(self.num_heads)
324
+
325
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
326
+
327
+ def forward(self, x):
328
+ return checkpoint(
329
+ self._forward, (x,), self.parameters(), True
330
+ ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
331
+ # return pt_checkpoint(self._forward, x) # pytorch
332
+
333
+ def _forward(self, x):
334
+ b, c, *spatial = x.shape
335
+ x = x.reshape(b, c, -1)
336
+ qkv = self.qkv(self.norm(x))
337
+ h = self.attention(qkv)
338
+ h = self.proj_out(h)
339
+ return (x + h).reshape(b, c, *spatial)
340
+
341
+
342
+ def count_flops_attn(model, _x, y):
343
+ """
344
+ A counter for the `thop` package to count the operations in an
345
+ attention operation.
346
+ Meant to be used like:
347
+ macs, params = thop.profile(
348
+ model,
349
+ inputs=(inputs, timestamps),
350
+ custom_ops={QKVAttention: QKVAttention.count_flops},
351
+ )
352
+ """
353
+ b, c, *spatial = y[0].shape
354
+ num_spatial = int(np.prod(spatial))
355
+ # We perform two matmuls with the same number of ops.
356
+ # The first computes the weight matrix, the second computes
357
+ # the combination of the value vectors.
358
+ matmul_ops = 2 * b * (num_spatial**2) * c
359
+ model.total_ops += th.DoubleTensor([matmul_ops])
360
+
361
+
362
+ class QKVAttentionLegacy(nn.Module):
363
+ """
364
+ A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
365
+ """
366
+
367
+ def __init__(self, n_heads):
368
+ super().__init__()
369
+ self.n_heads = n_heads
370
+
371
+ def forward(self, qkv):
372
+ """
373
+ Apply QKV attention.
374
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
375
+ :return: an [N x (H * C) x T] tensor after attention.
376
+ """
377
+ bs, width, length = qkv.shape
378
+ assert width % (3 * self.n_heads) == 0
379
+ ch = width // (3 * self.n_heads)
380
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
381
+ scale = 1 / math.sqrt(math.sqrt(ch))
382
+ weight = th.einsum(
383
+ "bct,bcs->bts", q * scale, k * scale
384
+ ) # More stable with f16 than dividing afterwards
385
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
386
+ a = th.einsum("bts,bcs->bct", weight, v)
387
+ return a.reshape(bs, -1, length)
388
+
389
+ @staticmethod
390
+ def count_flops(model, _x, y):
391
+ return count_flops_attn(model, _x, y)
392
+
393
+
394
+ class QKVAttention(nn.Module):
395
+ """
396
+ A module which performs QKV attention and splits in a different order.
397
+ """
398
+
399
+ def __init__(self, n_heads):
400
+ super().__init__()
401
+ self.n_heads = n_heads
402
+
403
+ def forward(self, qkv):
404
+ """
405
+ Apply QKV attention.
406
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
407
+ :return: an [N x (H * C) x T] tensor after attention.
408
+ """
409
+ bs, width, length = qkv.shape
410
+ assert width % (3 * self.n_heads) == 0
411
+ ch = width // (3 * self.n_heads)
412
+ q, k, v = qkv.chunk(3, dim=1)
413
+ scale = 1 / math.sqrt(math.sqrt(ch))
414
+ weight = th.einsum(
415
+ "bct,bcs->bts",
416
+ (q * scale).view(bs * self.n_heads, ch, length),
417
+ (k * scale).view(bs * self.n_heads, ch, length),
418
+ ) # More stable with f16 than dividing afterwards
419
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
420
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
421
+ return a.reshape(bs, -1, length)
422
+
423
+ @staticmethod
424
+ def count_flops(model, _x, y):
425
+ return count_flops_attn(model, _x, y)
426
+
427
+
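# Note (editorial, not part of the commit): the two attention modules above expect
# differently packed qkv tensors. QKVAttentionLegacy assumes [N, H*3*C, T] (heads
# outermost, q/k/v split per head), while QKVAttention assumes [N, 3*H*C, T]
# (q/k/v outermost). With n_heads=2 and ch=8 per head, for example, both accept a
# (1, 48, 16) tensor and return (1, 16, 16), but the results only agree when the
# packing matches the module; AttentionBlock selects between them via
# use_new_attention_order.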
428
+ class UNetModel(nn.Module):
429
+ """
430
+ The full UNet model with attention and timestep embedding.
431
+ :param in_channels: channels in the input Tensor.
432
+ :param model_channels: base channel count for the model.
433
+ :param out_channels: channels in the output Tensor.
434
+ :param num_res_blocks: number of residual blocks per downsample.
435
+ :param attention_resolutions: a collection of downsample rates at which
436
+ attention will take place. May be a set, list, or tuple.
437
+ For example, if this contains 4, then at 4x downsampling, attention
438
+ will be used.
439
+ :param dropout: the dropout probability.
440
+ :param channel_mult: channel multiplier for each level of the UNet.
441
+ :param conv_resample: if True, use learned convolutions for upsampling and
442
+ downsampling.
443
+ :param dims: determines if the signal is 1D, 2D, or 3D.
444
+ :param num_classes: if specified (as an int), then this model will be
445
+ class-conditional with `num_classes` classes.
446
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
447
+ :param num_heads: the number of attention heads in each attention layer.
448
+ :param num_heads_channels: if specified, ignore num_heads and instead use
449
+ a fixed channel width per attention head.
450
+ :param num_heads_upsample: works with num_heads to set a different number
451
+ of heads for upsampling. Deprecated.
452
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
453
+ :param resblock_updown: use residual blocks for up/downsampling.
454
+ :param use_new_attention_order: use a different attention pattern for potentially
455
+ increased efficiency.
456
+ """
457
+
458
+ def __init__(
459
+ self,
460
+ image_size,
461
+ in_channels,
462
+ model_channels,
463
+ out_channels,
464
+ num_res_blocks,
465
+ attention_resolutions,
466
+ dropout=0,
467
+ channel_mult=(1, 2, 4, 8),
468
+ conv_resample=True,
469
+ dims=2,
470
+ num_classes=None,
471
+ use_checkpoint=False,
472
+ use_fp16=False,
473
+ num_heads=-1,
474
+ num_head_channels=-1,
475
+ num_heads_upsample=-1,
476
+ use_scale_shift_norm=False,
477
+ resblock_updown=False,
478
+ use_new_attention_order=False,
479
+ use_spatial_transformer=False, # custom transformer support
480
+ transformer_depth=1, # custom transformer support
481
+ context_dim=None, # custom transformer support
482
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
483
+ legacy=True,
484
+ ):
485
+ super().__init__()
486
+ if use_spatial_transformer:
487
+ assert (
488
+ context_dim is not None
489
+ ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
490
+
491
+ if context_dim is not None:
492
+ assert (
493
+ use_spatial_transformer
494
+ ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
495
+ from omegaconf.listconfig import ListConfig
496
+
497
+ if type(context_dim) == ListConfig:
498
+ context_dim = list(context_dim)
499
+
500
+ if num_heads_upsample == -1:
501
+ num_heads_upsample = num_heads
502
+
503
+ if num_heads == -1:
504
+ assert (
505
+ num_head_channels != -1
506
+ ), "Either num_heads or num_head_channels has to be set"
507
+
508
+ if num_head_channels == -1:
509
+ assert (
510
+ num_heads != -1
511
+ ), "Either num_heads or num_head_channels has to be set"
512
+
513
+ self.image_size = image_size
514
+ self.in_channels = in_channels
515
+ self.model_channels = model_channels
516
+ self.out_channels = out_channels
517
+ self.num_res_blocks = num_res_blocks
518
+ self.attention_resolutions = attention_resolutions
519
+ self.dropout = dropout
520
+ self.channel_mult = channel_mult
521
+ self.conv_resample = conv_resample
522
+ self.num_classes = num_classes
523
+ self.use_checkpoint = use_checkpoint
524
+ self.dtype = th.float16 if use_fp16 else th.float32
525
+ self.num_heads = num_heads
526
+ self.num_head_channels = num_head_channels
527
+ self.num_heads_upsample = num_heads_upsample
528
+ self.predict_codebook_ids = n_embed is not None
529
+
530
+ time_embed_dim = model_channels * 4
531
+ self.time_embed = nn.Sequential(
532
+ linear(model_channels, time_embed_dim),
533
+ nn.SiLU(),
534
+ linear(time_embed_dim, time_embed_dim),
535
+ )
536
+
537
+ if self.num_classes is not None:
538
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
539
+
540
+ self.input_blocks = nn.ModuleList(
541
+ [
542
+ TimestepEmbedSequential(
543
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
544
+ )
545
+ ]
546
+ )
547
+ self._feature_size = model_channels
548
+ input_block_chans = [model_channels]
549
+ ch = model_channels
550
+ ds = 1
551
+ for level, mult in enumerate(channel_mult):
552
+ for _ in range(num_res_blocks):
553
+ layers = [
554
+ ResBlock(
555
+ ch,
556
+ time_embed_dim,
557
+ dropout,
558
+ out_channels=mult * model_channels,
559
+ dims=dims,
560
+ use_checkpoint=use_checkpoint,
561
+ use_scale_shift_norm=use_scale_shift_norm,
562
+ )
563
+ ]
564
+ ch = mult * model_channels
565
+ if ds in attention_resolutions:
566
+ if num_head_channels == -1:
567
+ dim_head = ch // num_heads
568
+ else:
569
+ num_heads = ch // num_head_channels
570
+ dim_head = num_head_channels
571
+ if legacy:
572
+ # num_heads = 1
573
+ dim_head = (
574
+ ch // num_heads
575
+ if use_spatial_transformer
576
+ else num_head_channels
577
+ )
578
+ layers.append(
579
+ AttentionBlock(
580
+ ch,
581
+ use_checkpoint=use_checkpoint,
582
+ num_heads=num_heads,
583
+ num_head_channels=dim_head,
584
+ use_new_attention_order=use_new_attention_order,
585
+ )
586
+ if not use_spatial_transformer
587
+ else SpatialTransformer(
588
+ ch,
589
+ num_heads,
590
+ dim_head,
591
+ depth=transformer_depth,
592
+ context_dim=context_dim,
593
+ )
594
+ )
595
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
596
+ self._feature_size += ch
597
+ input_block_chans.append(ch)
598
+ if level != len(channel_mult) - 1:
599
+ out_ch = ch
600
+ self.input_blocks.append(
601
+ TimestepEmbedSequential(
602
+ ResBlock(
603
+ ch,
604
+ time_embed_dim,
605
+ dropout,
606
+ out_channels=out_ch,
607
+ dims=dims,
608
+ use_checkpoint=use_checkpoint,
609
+ use_scale_shift_norm=use_scale_shift_norm,
610
+ down=True,
611
+ )
612
+ if resblock_updown
613
+ else Downsample(
614
+ ch, conv_resample, dims=dims, out_channels=out_ch
615
+ )
616
+ )
617
+ )
618
+ ch = out_ch
619
+ input_block_chans.append(ch)
620
+ ds *= 2
621
+ self._feature_size += ch
622
+
623
+ if num_head_channels == -1:
624
+ dim_head = ch // num_heads
625
+ else:
626
+ num_heads = ch // num_head_channels
627
+ dim_head = num_head_channels
628
+ if legacy:
629
+ # num_heads = 1
630
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
631
+ self.middle_block = TimestepEmbedSequential(
632
+ ResBlock(
633
+ ch,
634
+ time_embed_dim,
635
+ dropout,
636
+ dims=dims,
637
+ use_checkpoint=use_checkpoint,
638
+ use_scale_shift_norm=use_scale_shift_norm,
639
+ ),
640
+ (
641
+ AttentionBlock(
642
+ ch,
643
+ use_checkpoint=use_checkpoint,
644
+ num_heads=num_heads,
645
+ num_head_channels=dim_head,
646
+ use_new_attention_order=use_new_attention_order,
647
+ )
648
+ if not use_spatial_transformer
649
+ else SpatialTransformer(
650
+ ch,
651
+ num_heads,
652
+ dim_head,
653
+ depth=transformer_depth,
654
+ context_dim=context_dim,
655
+ )
656
+ ),
657
+ ResBlock(
658
+ ch,
659
+ time_embed_dim,
660
+ dropout,
661
+ dims=dims,
662
+ use_checkpoint=use_checkpoint,
663
+ use_scale_shift_norm=use_scale_shift_norm,
664
+ ),
665
+ )
666
+ self._feature_size += ch
667
+
668
+ self.output_blocks = nn.ModuleList([])
669
+ for level, mult in list(enumerate(channel_mult))[::-1]:
670
+ for i in range(num_res_blocks + 1):
671
+ ich = input_block_chans.pop()
672
+ layers = [
673
+ ResBlock(
674
+ ch + ich,
675
+ time_embed_dim,
676
+ dropout,
677
+ out_channels=model_channels * mult,
678
+ dims=dims,
679
+ use_checkpoint=use_checkpoint,
680
+ use_scale_shift_norm=use_scale_shift_norm,
681
+ )
682
+ ]
683
+ ch = model_channels * mult
684
+ if ds in attention_resolutions:
685
+ if num_head_channels == -1:
686
+ dim_head = ch // num_heads
687
+ else:
688
+ num_heads = ch // num_head_channels
689
+ dim_head = num_head_channels
690
+ if legacy:
691
+ # num_heads = 1
692
+ dim_head = (
693
+ ch // num_heads
694
+ if use_spatial_transformer
695
+ else num_head_channels
696
+ )
697
+ layers.append(
698
+ AttentionBlock(
699
+ ch,
700
+ use_checkpoint=use_checkpoint,
701
+ num_heads=num_heads_upsample,
702
+ num_head_channels=dim_head,
703
+ use_new_attention_order=use_new_attention_order,
704
+ )
705
+ if not use_spatial_transformer
706
+ else SpatialTransformer(
707
+ ch,
708
+ num_heads,
709
+ dim_head,
710
+ depth=transformer_depth,
711
+ context_dim=context_dim,
712
+ )
713
+ )
714
+ if level and i == num_res_blocks:
715
+ out_ch = ch
716
+ layers.append(
717
+ ResBlock(
718
+ ch,
719
+ time_embed_dim,
720
+ dropout,
721
+ out_channels=out_ch,
722
+ dims=dims,
723
+ use_checkpoint=use_checkpoint,
724
+ use_scale_shift_norm=use_scale_shift_norm,
725
+ up=True,
726
+ )
727
+ if resblock_updown
728
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
729
+ )
730
+ ds //= 2
731
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
732
+ self._feature_size += ch
733
+
734
+ self.out = nn.Sequential(
735
+ normalization(ch),
736
+ nn.SiLU(),
737
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
738
+ )
739
+ if self.predict_codebook_ids:
740
+ self.id_predictor = nn.Sequential(
741
+ normalization(ch),
742
+ conv_nd(dims, model_channels, n_embed, 1),
743
+ # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
744
+ )
745
+
746
+ def convert_to_fp16(self):
747
+ """
748
+ Convert the torso of the model to float16.
749
+ """
750
+ self.input_blocks.apply(convert_module_to_f16)
751
+ self.middle_block.apply(convert_module_to_f16)
752
+ self.output_blocks.apply(convert_module_to_f16)
753
+
754
+ def convert_to_fp32(self):
755
+ """
756
+ Convert the torso of the model to float32.
757
+ """
758
+ self.input_blocks.apply(convert_module_to_f32)
759
+ self.middle_block.apply(convert_module_to_f32)
760
+ self.output_blocks.apply(convert_module_to_f32)
761
+
762
+ def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
763
+ """
764
+ Apply the model to an input batch.
765
+ :param x: an [N x C x ...] Tensor of inputs.
766
+ :param timesteps: a 1-D batch of timesteps.
767
+ :param context: conditioning plugged in via crossattn
768
+ :param y: an [N] Tensor of labels, if class-conditional.
769
+ :return: an [N x C x ...] Tensor of outputs.
770
+ """
771
+ assert (y is not None) == (
772
+ self.num_classes is not None
773
+ ), "must specify y if and only if the model is class-conditional"
774
+ hs = []
775
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
776
+ emb = self.time_embed(t_emb)
777
+
778
+ if self.num_classes is not None:
779
+ assert y.shape == (x.shape[0],)
780
+ emb = emb + self.label_emb(y)
781
+
782
+ h = x.type(self.dtype)
783
+ for module in self.input_blocks:
784
+ h = module(h, emb, context)
785
+ hs.append(h)
786
+ h = self.middle_block(h, emb, context)
787
+ for module in self.output_blocks:
788
+ h = th.cat([h, hs.pop()], dim=1)
789
+ h = module(h, emb, context)
790
+ h = h.type(x.dtype)
791
+ if self.predict_codebook_ids:
792
+ return self.id_predictor(h)
793
+ else:
794
+ return self.out(h)
795
+
796
+
797
+ class EncoderUNetModel(nn.Module):
798
+ """
799
+ The half UNet model with attention and timestep embedding.
800
+ For usage, see UNet.
801
+ """
802
+
803
+ def __init__(
804
+ self,
805
+ image_size,
806
+ in_channels,
807
+ model_channels,
808
+ out_channels,
809
+ num_res_blocks,
810
+ attention_resolutions,
811
+ dropout=0,
812
+ channel_mult=(1, 2, 4, 8),
813
+ conv_resample=True,
814
+ dims=2,
815
+ use_checkpoint=False,
816
+ use_fp16=False,
817
+ num_heads=1,
818
+ num_head_channels=-1,
819
+ num_heads_upsample=-1,
820
+ use_scale_shift_norm=False,
821
+ resblock_updown=False,
822
+ use_new_attention_order=False,
823
+ pool="adaptive",
824
+ *args,
825
+ **kwargs,
826
+ ):
827
+ super().__init__()
828
+
829
+ if num_heads_upsample == -1:
830
+ num_heads_upsample = num_heads
831
+
832
+ self.in_channels = in_channels
833
+ self.model_channels = model_channels
834
+ self.out_channels = out_channels
835
+ self.num_res_blocks = num_res_blocks
836
+ self.attention_resolutions = attention_resolutions
837
+ self.dropout = dropout
838
+ self.channel_mult = channel_mult
839
+ self.conv_resample = conv_resample
840
+ self.use_checkpoint = use_checkpoint
841
+ self.dtype = th.float16 if use_fp16 else th.float32
842
+ self.num_heads = num_heads
843
+ self.num_head_channels = num_head_channels
844
+ self.num_heads_upsample = num_heads_upsample
845
+
846
+ time_embed_dim = model_channels * 4
847
+ self.time_embed = nn.Sequential(
848
+ linear(model_channels, time_embed_dim),
849
+ nn.SiLU(),
850
+ linear(time_embed_dim, time_embed_dim),
851
+ )
852
+
853
+ self.input_blocks = nn.ModuleList(
854
+ [
855
+ TimestepEmbedSequential(
856
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
857
+ )
858
+ ]
859
+ )
860
+ self._feature_size = model_channels
861
+ input_block_chans = [model_channels]
862
+ ch = model_channels
863
+ ds = 1
864
+ for level, mult in enumerate(channel_mult):
865
+ for _ in range(num_res_blocks):
866
+ layers = [
867
+ ResBlock(
868
+ ch,
869
+ time_embed_dim,
870
+ dropout,
871
+ out_channels=mult * model_channels,
872
+ dims=dims,
873
+ use_checkpoint=use_checkpoint,
874
+ use_scale_shift_norm=use_scale_shift_norm,
875
+ )
876
+ ]
877
+ ch = mult * model_channels
878
+ if ds in attention_resolutions:
879
+ layers.append(
880
+ AttentionBlock(
881
+ ch,
882
+ use_checkpoint=use_checkpoint,
883
+ num_heads=num_heads,
884
+ num_head_channels=num_head_channels,
885
+ use_new_attention_order=use_new_attention_order,
886
+ )
887
+ )
888
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
889
+ self._feature_size += ch
890
+ input_block_chans.append(ch)
891
+ if level != len(channel_mult) - 1:
892
+ out_ch = ch
893
+ self.input_blocks.append(
894
+ TimestepEmbedSequential(
895
+ ResBlock(
896
+ ch,
897
+ time_embed_dim,
898
+ dropout,
899
+ out_channels=out_ch,
900
+ dims=dims,
901
+ use_checkpoint=use_checkpoint,
902
+ use_scale_shift_norm=use_scale_shift_norm,
903
+ down=True,
904
+ )
905
+ if resblock_updown
906
+ else Downsample(
907
+ ch, conv_resample, dims=dims, out_channels=out_ch
908
+ )
909
+ )
910
+ )
911
+ ch = out_ch
912
+ input_block_chans.append(ch)
913
+ ds *= 2
914
+ self._feature_size += ch
915
+
916
+ self.middle_block = TimestepEmbedSequential(
917
+ ResBlock(
918
+ ch,
919
+ time_embed_dim,
920
+ dropout,
921
+ dims=dims,
922
+ use_checkpoint=use_checkpoint,
923
+ use_scale_shift_norm=use_scale_shift_norm,
924
+ ),
925
+ AttentionBlock(
926
+ ch,
927
+ use_checkpoint=use_checkpoint,
928
+ num_heads=num_heads,
929
+ num_head_channels=num_head_channels,
930
+ use_new_attention_order=use_new_attention_order,
931
+ ),
932
+ ResBlock(
933
+ ch,
934
+ time_embed_dim,
935
+ dropout,
936
+ dims=dims,
937
+ use_checkpoint=use_checkpoint,
938
+ use_scale_shift_norm=use_scale_shift_norm,
939
+ ),
940
+ )
941
+ self._feature_size += ch
942
+ self.pool = pool
943
+ if pool == "adaptive":
944
+ self.out = nn.Sequential(
945
+ normalization(ch),
946
+ nn.SiLU(),
947
+ nn.AdaptiveAvgPool2d((1, 1)),
948
+ zero_module(conv_nd(dims, ch, out_channels, 1)),
949
+ nn.Flatten(),
950
+ )
951
+ elif pool == "attention":
952
+ assert num_head_channels != -1
953
+ self.out = nn.Sequential(
954
+ normalization(ch),
955
+ nn.SiLU(),
956
+ AttentionPool2d(
957
+ (image_size // ds), ch, num_head_channels, out_channels
958
+ ),
959
+ )
960
+ elif pool == "spatial":
961
+ self.out = nn.Sequential(
962
+ nn.Linear(self._feature_size, 2048),
963
+ nn.ReLU(),
964
+ nn.Linear(2048, self.out_channels),
965
+ )
966
+ elif pool == "spatial_v2":
967
+ self.out = nn.Sequential(
968
+ nn.Linear(self._feature_size, 2048),
969
+ normalization(2048),
970
+ nn.SiLU(),
971
+ nn.Linear(2048, self.out_channels),
972
+ )
973
+ else:
974
+ raise NotImplementedError(f"Unexpected {pool} pooling")
975
+
976
+ def convert_to_fp16(self):
977
+ """
978
+ Convert the torso of the model to float16.
979
+ """
980
+ self.input_blocks.apply(convert_module_to_f16)
981
+ self.middle_block.apply(convert_module_to_f16)
982
+
983
+ def convert_to_fp32(self):
984
+ """
985
+ Convert the torso of the model to float32.
986
+ """
987
+ self.input_blocks.apply(convert_module_to_f32)
988
+ self.middle_block.apply(convert_module_to_f32)
989
+
990
+ def forward(self, x, timesteps):
991
+ """
992
+ Apply the model to an input batch.
993
+ :param x: an [N x C x ...] Tensor of inputs.
994
+ :param timesteps: a 1-D batch of timesteps.
995
+ :return: an [N x K] Tensor of outputs.
996
+ """
997
+ emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
998
+
999
+ results = []
1000
+ h = x.type(self.dtype)
1001
+ for module in self.input_blocks:
1002
+ h = module(h, emb)
1003
+ if self.pool.startswith("spatial"):
1004
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
1005
+ h = self.middle_block(h, emb)
1006
+ if self.pool.startswith("spatial"):
1007
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
1008
+ h = th.cat(results, axis=-1)
1009
+ return self.out(h)
1010
+ else:
1011
+ h = h.type(x.dtype)
1012
+ return self.out(h)
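As a quick sanity check of the encoder half defined above, the sketch below instantiates EncoderUNetModel and runs one forward pass. The hyper-parameters are illustrative only, and the import path is an assumption (the file name of this hunk is not visible here).

import torch as th
# assumed module path; adjust to wherever this class actually lives in the repo
from swim.modules.diffusionmodules.openaimodel import EncoderUNetModel

model = EncoderUNetModel(
    image_size=64,
    in_channels=3,
    model_channels=64,
    out_channels=10,            # e.g. number of classes for a noisy classifier
    num_res_blocks=2,
    attention_resolutions=(8, 16),
    channel_mult=(1, 2, 4),
    pool="adaptive",
)
x = th.randn(2, 3, 64, 64)      # batch of (noisy) images
t = th.randint(0, 1000, (2,))   # diffusion timesteps
logits = model(x, t)            # -> [2, 10]
print(logits.shape)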
swim/modules/diffusionmodules/util.py ADDED
@@ -0,0 +1,296 @@
1
+ # adopted from
2
+ # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3
+ # and
4
+ # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5
+ # and
6
+ # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7
+ #
8
+ # thanks!
9
+
10
+
11
+ import os
12
+ import math
13
+ import torch
14
+ import torch.nn as nn
15
+ import numpy as np
16
+ from einops import repeat
17
+
18
+ from swim.utils import instantiate_from_config
19
+
20
+
21
+ def make_beta_schedule(
22
+ schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
23
+ ):
24
+ if schedule == "linear":
25
+ betas = (
26
+ torch.linspace(
27
+ linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
28
+ )
29
+ ** 2
30
+ )
31
+
32
+ elif schedule == "cosine":
33
+ timesteps = (
34
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
35
+ )
36
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
37
+ alphas = torch.cos(alphas).pow(2)
38
+ alphas = alphas / alphas[0]
39
+ betas = 1 - alphas[1:] / alphas[:-1]
40
+ betas = np.clip(betas, a_min=0, a_max=0.999)
41
+
42
+ elif schedule == "sqrt_linear":
43
+ betas = torch.linspace(
44
+ linear_start, linear_end, n_timestep, dtype=torch.float64
45
+ )
46
+ elif schedule == "sqrt":
47
+ betas = (
48
+ torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
49
+ ** 0.5
50
+ )
51
+ else:
52
+ raise ValueError(f"schedule '{schedule}' unknown.")
53
+ return betas.numpy()
54
+
55
+
56
+ def make_ddim_timesteps(
57
+ ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True
58
+ ):
59
+ if ddim_discr_method == "uniform":
60
+ c = num_ddpm_timesteps // num_ddim_timesteps
61
+ ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
62
+ elif ddim_discr_method == "quad":
63
+ ddim_timesteps = (
64
+ (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2
65
+ ).astype(int)
66
+ else:
67
+ raise NotImplementedError(
68
+ f'There is no ddim discretization method called "{ddim_discr_method}"'
69
+ )
70
+
71
+ # assert ddim_timesteps.shape[0] == num_ddim_timesteps
72
+ # add one to get the final alpha values right (the ones from first scale to data during sampling)
73
+ steps_out = ddim_timesteps + 1
74
+ if verbose:
75
+ print(f"Selected timesteps for ddim sampler: {steps_out}")
76
+ return steps_out
77
+
78
+
79
+ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
80
+ # select alphas for computing the variance schedule
81
+ alphas = alphacums[ddim_timesteps]
82
+ alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
83
+
84
+ # according to the formula provided in https://arxiv.org/abs/2010.02502
85
+ sigmas = eta * np.sqrt(
86
+ (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
87
+ )
88
+ if verbose:
89
+ print(
90
+ f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}"
91
+ )
92
+ print(
93
+ f"For the chosen value of eta, which is {eta}, "
94
+ f"this results in the following sigma_t schedule for ddim sampler {sigmas}"
95
+ )
96
+ return sigmas, alphas, alphas_prev
97
+
98
+
99
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
100
+ """
101
+ Create a beta schedule that discretizes the given alpha_t_bar function,
102
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
103
+ :param num_diffusion_timesteps: the number of betas to produce.
104
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
105
+ produces the cumulative product of (1-beta) up to that
106
+ part of the diffusion process.
107
+ :param max_beta: the maximum beta to use; use values lower than 1 to
108
+ prevent singularities.
109
+ """
110
+ betas = []
111
+ for i in range(num_diffusion_timesteps):
112
+ t1 = i / num_diffusion_timesteps
113
+ t2 = (i + 1) / num_diffusion_timesteps
114
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
115
+ return np.array(betas)
116
+
117
+
118
+ def extract_into_tensor(a, t, x_shape):
119
+ b, *_ = t.shape
120
+ out = a.gather(-1, t)
121
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
122
+
123
+
124
+ def checkpoint(func, inputs, params, flag):
125
+ """
126
+ Evaluate a function without caching intermediate activations, allowing for
127
+ reduced memory at the expense of extra compute in the backward pass.
128
+ :param func: the function to evaluate.
129
+ :param inputs: the argument sequence to pass to `func`.
130
+ :param params: a sequence of parameters `func` depends on but does not
131
+ explicitly take as arguments.
132
+ :param flag: if False, disable gradient checkpointing.
133
+ """
134
+ if flag:
135
+ args = tuple(inputs) + tuple(params)
136
+ return CheckpointFunction.apply(func, len(inputs), *args)
137
+ else:
138
+ return func(*inputs)
139
+
140
+
141
+ class CheckpointFunction(torch.autograd.Function):
142
+ @staticmethod
143
+ def forward(ctx, run_function, length, *args):
144
+ ctx.run_function = run_function
145
+ ctx.input_tensors = list(args[:length])
146
+ ctx.input_params = list(args[length:])
147
+
148
+ with torch.no_grad():
149
+ output_tensors = ctx.run_function(*ctx.input_tensors)
150
+ return output_tensors
151
+
152
+ @staticmethod
153
+ def backward(ctx, *output_grads):
154
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
155
+ with torch.enable_grad():
156
+ # Fixes a bug where the first op in run_function modifies the
157
+ # Tensor storage in place, which is not allowed for detach()'d
158
+ # Tensors.
159
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
160
+ output_tensors = ctx.run_function(*shallow_copies)
161
+ input_grads = torch.autograd.grad(
162
+ output_tensors,
163
+ ctx.input_tensors + ctx.input_params,
164
+ output_grads,
165
+ allow_unused=True,
166
+ )
167
+ del ctx.input_tensors
168
+ del ctx.input_params
169
+ del output_tensors
170
+ return (None, None) + input_grads
171
+
172
+
173
+ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
174
+ """
175
+ Create sinusoidal timestep embeddings.
176
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
177
+ These may be fractional.
178
+ :param dim: the dimension of the output.
179
+ :param max_period: controls the minimum frequency of the embeddings.
180
+ :return: an [N x dim] Tensor of positional embeddings.
181
+ """
182
+ if not repeat_only:
183
+ half = dim // 2
184
+ freqs = torch.exp(
185
+ -math.log(max_period)
186
+ * torch.arange(start=0, end=half, dtype=torch.float32)
187
+ / half
188
+ ).to(device=timesteps.device)
189
+ args = timesteps[:, None].float() * freqs[None]
190
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
191
+ if dim % 2:
192
+ embedding = torch.cat(
193
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
194
+ )
195
+ else:
196
+ embedding = repeat(timesteps, "b -> b d", d=dim)
197
+ return embedding
198
+
199
+
200
+ def zero_module(module):
201
+ """
202
+ Zero out the parameters of a module and return it.
203
+ """
204
+ for p in module.parameters():
205
+ p.detach().zero_()
206
+ return module
207
+
208
+
209
+ def scale_module(module, scale):
210
+ """
211
+ Scale the parameters of a module and return it.
212
+ """
213
+ for p in module.parameters():
214
+ p.detach().mul_(scale)
215
+ return module
216
+
217
+
218
+ def mean_flat(tensor):
219
+ """
220
+ Take the mean over all non-batch dimensions.
221
+ """
222
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
223
+
224
+
225
+ def normalization(channels):
226
+ """
227
+ Make a standard normalization layer.
228
+ :param channels: number of input channels.
229
+ :return: an nn.Module for normalization.
230
+ """
231
+ return GroupNorm32(32, channels)
232
+
233
+
234
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
235
+ class SiLU(nn.Module):
236
+ def forward(self, x):
237
+ return x * torch.sigmoid(x)
238
+
239
+
240
+ class GroupNorm32(nn.GroupNorm):
241
+ def forward(self, x):
242
+ return super().forward(x.float()).type(x.dtype)
243
+
244
+
245
+ def conv_nd(dims, *args, **kwargs):
246
+ """
247
+ Create a 1D, 2D, or 3D convolution module.
248
+ """
249
+ if dims == 1:
250
+ return nn.Conv1d(*args, **kwargs)
251
+ elif dims == 2:
252
+ return nn.Conv2d(*args, **kwargs)
253
+ elif dims == 3:
254
+ return nn.Conv3d(*args, **kwargs)
255
+ raise ValueError(f"unsupported dimensions: {dims}")
256
+
257
+
258
+ def linear(*args, **kwargs):
259
+ """
260
+ Create a linear module.
261
+ """
262
+ return nn.Linear(*args, **kwargs)
263
+
264
+
265
+ def avg_pool_nd(dims, *args, **kwargs):
266
+ """
267
+ Create a 1D, 2D, or 3D average pooling module.
268
+ """
269
+ if dims == 1:
270
+ return nn.AvgPool1d(*args, **kwargs)
271
+ elif dims == 2:
272
+ return nn.AvgPool2d(*args, **kwargs)
273
+ elif dims == 3:
274
+ return nn.AvgPool3d(*args, **kwargs)
275
+ raise ValueError(f"unsupported dimensions: {dims}")
276
+
277
+
278
+ class HybridConditioner(nn.Module):
279
+
280
+ def __init__(self, c_concat_config, c_crossattn_config):
281
+ super().__init__()
282
+ self.concat_conditioner = instantiate_from_config(c_concat_config)
283
+ self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
284
+
285
+ def forward(self, c_concat, c_crossattn):
286
+ c_concat = self.concat_conditioner(c_concat)
287
+ c_crossattn = self.crossattn_conditioner(c_crossattn)
288
+ return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]}
289
+
290
+
291
+ def noise_like(shape, device, repeat=False):
292
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(
293
+ shape[0], *((1,) * (len(shape) - 1))
294
+ )
295
+ noise = lambda: torch.randn(shape, device=device)
296
+ return repeat_noise() if repeat else noise()
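For reference, a minimal sketch of two helpers added in this file, make_beta_schedule and timestep_embedding; the schedule length and embedding dimension are arbitrary example values.

import torch
from swim.modules.diffusionmodules.util import make_beta_schedule, timestep_embedding

betas = make_beta_schedule("linear", n_timestep=1000)      # numpy array of shape (1000,)
alphas_cumprod = (1.0 - torch.from_numpy(betas)).cumprod(dim=0)

t = torch.tensor([0, 250, 999])
emb = timestep_embedding(t, dim=128)                       # [3, 128] sinusoidal embedding
print(betas.shape, float(alphas_cumprod[-1]), emb.shape)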
swim/modules/discriminators/n_layer_discriminator.py ADDED
@@ -0,0 +1,88 @@
1
+ import functools
2
+ import torch.nn as nn
3
+
4
+
5
+ from swim.utils import ActNorm
6
+
7
+
8
+ def weights_init(m):
9
+ classname = m.__class__.__name__
10
+ if classname.find("Conv") != -1:
11
+ nn.init.normal_(m.weight.data, 0.0, 0.02)
12
+ elif classname.find("BatchNorm") != -1:
13
+ nn.init.normal_(m.weight.data, 1.0, 0.02)
14
+ nn.init.constant_(m.bias.data, 0)
15
+
16
+
17
+ class NLayerDiscriminator(nn.Module):
18
+ """Defines a PatchGAN discriminator as in Pix2Pix
19
+ --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
20
+ """
21
+
22
+ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
23
+ """Construct a PatchGAN discriminator
24
+ Parameters:
25
+ input_nc (int) -- the number of channels in input images
26
+ ndf (int) -- the number of filters in the last conv layer
27
+ n_layers (int) -- the number of conv layers in the discriminator
28
+ norm_layer -- normalization layer
29
+ """
30
+ super(NLayerDiscriminator, self).__init__()
31
+ if not use_actnorm:
32
+ norm_layer = nn.BatchNorm2d
33
+ else:
34
+ norm_layer = ActNorm
35
+ if (
36
+ type(norm_layer) == functools.partial
37
+ ): # no need to use bias as BatchNorm2d has affine parameters
38
+ use_bias = norm_layer.func != nn.BatchNorm2d
39
+ else:
40
+ use_bias = norm_layer != nn.BatchNorm2d
41
+
42
+ kw = 4
43
+ padw = 1
44
+ sequence = [
45
+ nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
46
+ nn.LeakyReLU(0.2, True),
47
+ ]
48
+ nf_mult = 1
49
+ nf_mult_prev = 1
50
+ for n in range(1, n_layers): # gradually increase the number of filters
51
+ nf_mult_prev = nf_mult
52
+ nf_mult = min(2**n, 8)
53
+ sequence += [
54
+ nn.Conv2d(
55
+ ndf * nf_mult_prev,
56
+ ndf * nf_mult,
57
+ kernel_size=kw,
58
+ stride=2,
59
+ padding=padw,
60
+ bias=use_bias,
61
+ ),
62
+ norm_layer(ndf * nf_mult),
63
+ nn.LeakyReLU(0.2, True),
64
+ ]
65
+
66
+ nf_mult_prev = nf_mult
67
+ nf_mult = min(2**n_layers, 8)
68
+ sequence += [
69
+ nn.Conv2d(
70
+ ndf * nf_mult_prev,
71
+ ndf * nf_mult,
72
+ kernel_size=kw,
73
+ stride=1,
74
+ padding=padw,
75
+ bias=use_bias,
76
+ ),
77
+ norm_layer(ndf * nf_mult),
78
+ nn.LeakyReLU(0.2, True),
79
+ ]
80
+
81
+ sequence += [
82
+ nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
83
+ ] # output 1 channel prediction map
84
+ self.main = nn.Sequential(*sequence)
85
+
86
+ def forward(self, input):
87
+ """Standard forward."""
88
+ return self.main(input)
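A short sketch of the PatchGAN discriminator above; the input shape is illustrative, and the output is a per-patch logit map rather than a single scalar.

import torch
from swim.modules.discriminators.n_layer_discriminator import (
    NLayerDiscriminator,
    weights_init,
)

disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3).apply(weights_init)
fake = torch.randn(4, 3, 256, 256)
logits = disc(fake)            # one logit per patch, e.g. [4, 1, 30, 30] here
print(logits.shape)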
swim/modules/distributions/__init__.py ADDED
File without changes
swim/modules/distributions/distributions.py ADDED
@@ -0,0 +1,92 @@
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
+ class AbstractDistribution:
6
+ def sample(self):
7
+ raise NotImplementedError()
8
+
9
+ def mode(self):
10
+ raise NotImplementedError()
11
+
12
+
13
+ class DiracDistribution(AbstractDistribution):
14
+ def __init__(self, value):
15
+ self.value = value
16
+
17
+ def sample(self):
18
+ return self.value
19
+
20
+ def mode(self):
21
+ return self.value
22
+
23
+
24
+ class DiagonalGaussianDistribution(object):
25
+ def __init__(self, parameters, deterministic=False):
26
+ self.parameters = parameters
27
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29
+ self.deterministic = deterministic
30
+ self.std = torch.exp(0.5 * self.logvar)
31
+ self.var = torch.exp(self.logvar)
32
+ if self.deterministic:
33
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34
+
35
+ def sample(self):
36
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37
+ return x
38
+
39
+ def kl(self, other=None):
40
+ if self.deterministic:
41
+ return torch.Tensor([0.])
42
+ else:
43
+ if other is None:
44
+ return 0.5 * torch.sum(torch.pow(self.mean, 2)
45
+ + self.var - 1.0 - self.logvar,
46
+ dim=[1, 2, 3])
47
+ else:
48
+ return 0.5 * torch.sum(
49
+ torch.pow(self.mean - other.mean, 2) / other.var
50
+ + self.var / other.var - 1.0 - self.logvar + other.logvar,
51
+ dim=[1, 2, 3])
52
+
53
+ def nll(self, sample, dims=[1,2,3]):
54
+ if self.deterministic:
55
+ return torch.Tensor([0.])
56
+ logtwopi = np.log(2.0 * np.pi)
57
+ return 0.5 * torch.sum(
58
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
59
+ dim=dims)
60
+
61
+ def mode(self):
62
+ return self.mean
63
+
64
+
65
+ def normal_kl(mean1, logvar1, mean2, logvar2):
66
+ """
67
+ source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
68
+ Compute the KL divergence between two gaussians.
69
+ Shapes are automatically broadcasted, so batches can be compared to
70
+ scalars, among other use cases.
71
+ """
72
+ tensor = None
73
+ for obj in (mean1, logvar1, mean2, logvar2):
74
+ if isinstance(obj, torch.Tensor):
75
+ tensor = obj
76
+ break
77
+ assert tensor is not None, "at least one argument must be a Tensor"
78
+
79
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
80
+ # Tensors, but it does not work for torch.exp().
81
+ logvar1, logvar2 = [
82
+ x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
83
+ for x in (logvar1, logvar2)
84
+ ]
85
+
86
+ return 0.5 * (
87
+ -1.0
88
+ + logvar2
89
+ - logvar1
90
+ + torch.exp(logvar1 - logvar2)
91
+ + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
92
+ )
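A sketch of how DiagonalGaussianDistribution is typically consumed by the KL autoencoder: the concatenated mean/log-variance channels would normally come from the encoder, and the shapes below are example values.

import torch
from swim.modules.distributions.distributions import DiagonalGaussianDistribution

moments = torch.randn(2, 8, 32, 32)      # 2 * z_channels along dim=1
posterior = DiagonalGaussianDistribution(moments)
z = posterior.sample()                   # [2, 4, 32, 32] latent sample
kl = posterior.kl()                      # [2], KL to N(0, I) summed over C, H, W
print(z.shape, kl.shape)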
swim/modules/ema.py ADDED
@@ -0,0 +1,76 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ class LitEma(nn.Module):
6
+ def __init__(self, model, decay=0.9999, use_num_upates=True):
7
+ super().__init__()
8
+ if decay < 0.0 or decay > 1.0:
9
+ raise ValueError('Decay must be between 0 and 1')
10
+
11
+ self.m_name2s_name = {}
12
+ self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
13
+ self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
14
+ else torch.tensor(-1,dtype=torch.int))
15
+
16
+ for name, p in model.named_parameters():
17
+ if p.requires_grad:
18
+ #remove as '.'-character is not allowed in buffers
19
+ s_name = name.replace('.','')
20
+ self.m_name2s_name.update({name:s_name})
21
+ self.register_buffer(s_name,p.clone().detach().data)
22
+
23
+ self.collected_params = []
24
+
25
+ def forward(self,model):
26
+ decay = self.decay
27
+
28
+ if self.num_updates >= 0:
29
+ self.num_updates += 1
30
+ decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
31
+
32
+ one_minus_decay = 1.0 - decay
33
+
34
+ with torch.no_grad():
35
+ m_param = dict(model.named_parameters())
36
+ shadow_params = dict(self.named_buffers())
37
+
38
+ for key in m_param:
39
+ if m_param[key].requires_grad:
40
+ sname = self.m_name2s_name[key]
41
+ shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
42
+ shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
43
+ else:
44
+ assert not key in self.m_name2s_name
45
+
46
+ def copy_to(self, model):
47
+ m_param = dict(model.named_parameters())
48
+ shadow_params = dict(self.named_buffers())
49
+ for key in m_param:
50
+ if m_param[key].requires_grad:
51
+ m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
52
+ else:
53
+ assert not key in self.m_name2s_name
54
+
55
+ def store(self, parameters):
56
+ """
57
+ Save the current parameters for restoring later.
58
+ Args:
59
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
60
+ temporarily stored.
61
+ """
62
+ self.collected_params = [param.clone() for param in parameters]
63
+
64
+ def restore(self, parameters):
65
+ """
66
+ Restore the parameters stored with the `store` method.
67
+ Useful to validate the model with EMA parameters without affecting the
68
+ original optimization process. Store the parameters before the
69
+ `copy_to` method. After validation (or model saving), use this to
70
+ restore the former parameters.
71
+ Args:
72
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
73
+ updated with the stored parameters.
74
+ """
75
+ for c_param, param in zip(self.collected_params, parameters):
76
+ param.data.copy_(c_param.data)
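A sketch of the usual LitEma workflow, i.e. update the shadow weights after each optimiser step and temporarily swap them in for evaluation; the nn.Linear model is just a stand-in.

import torch
from torch import nn
from swim.modules.ema import LitEma

model = nn.Linear(4, 4)
ema = LitEma(model, decay=0.999)

for _ in range(3):          # after each training step
    ema(model)              # update the shadow (EMA) parameters

params = list(model.parameters())
ema.store(params)           # remember the live weights
ema.copy_to(model)          # ... run validation with EMA weights here ...
ema.restore(params)         # put the live weights back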
swim/modules/losses/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from swim.modules.losses.contperceptual import LPIPSWithDiscriminator
swim/modules/losses/contperceptual.py ADDED
@@ -0,0 +1,181 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from lpips import LPIPS
5
+
6
+ from swim.modules.discriminators.n_layer_discriminator import (
7
+ NLayerDiscriminator,
8
+ weights_init,
9
+ )
10
+
11
+
12
+ def adopt_weight(weight, global_step, threshold=0, value=0.0):
13
+ if global_step < threshold:
14
+ weight = value
15
+ return weight
16
+
17
+
18
+ def hinge_d_loss(logits_real, logits_fake):
19
+ loss_real = torch.mean(F.relu(1.0 - logits_real))
20
+ loss_fake = torch.mean(F.relu(1.0 + logits_fake))
21
+ d_loss = 0.5 * (loss_real + loss_fake)
22
+ return d_loss
23
+
24
+
25
+ def vanilla_d_loss(logits_real, logits_fake):
26
+ d_loss = 0.5 * (
27
+ torch.mean(torch.nn.functional.softplus(-logits_real))
28
+ + torch.mean(torch.nn.functional.softplus(logits_fake))
29
+ )
30
+ return d_loss
31
+
32
+
33
+ class LPIPSWithDiscriminator(nn.Module):
34
+ def __init__(
35
+ self,
36
+ disc_start,
37
+ logvar_init=0.0,
38
+ kl_weight=1.0,
39
+ pixelloss_weight=1.0,
40
+ disc_num_layers=3,
41
+ disc_in_channels=3,
42
+ disc_factor=1.0,
43
+ disc_weight=1.0,
44
+ perceptual_weight=1.0,
45
+ use_actnorm=False,
46
+ disc_conditional=False,
47
+ disc_loss="hinge",
48
+ ):
49
+
50
+ super().__init__()
51
+ assert disc_loss in ["hinge", "vanilla"]
52
+ self.kl_weight = kl_weight
53
+ self.pixel_weight = pixelloss_weight
54
+ self.perceptual_loss = LPIPS(net="vgg").eval()
55
+ self.perceptual_weight = perceptual_weight
56
+ # output log variance
57
+ self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
58
+
59
+ self.discriminator = NLayerDiscriminator(
60
+ input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=use_actnorm
61
+ ).apply(weights_init)
62
+ self.discriminator_iter_start = disc_start
63
+ self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
64
+ self.disc_factor = disc_factor
65
+ self.discriminator_weight = disc_weight
66
+ self.disc_conditional = disc_conditional
67
+
68
+ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
69
+ if last_layer is not None:
70
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
71
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
72
+ else:
73
+ nll_grads = torch.autograd.grad(
74
+ nll_loss, self.last_layer[0], retain_graph=True
75
+ )[0]
76
+ g_grads = torch.autograd.grad(
77
+ g_loss, self.last_layer[0], retain_graph=True
78
+ )[0]
79
+
80
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
81
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
82
+ d_weight = d_weight * self.discriminator_weight
83
+ return d_weight
84
+
85
+ def forward(
86
+ self,
87
+ inputs,
88
+ reconstructions,
89
+ posteriors,
90
+ optimizer_idx,
91
+ global_step,
92
+ last_layer=None,
93
+ cond=None,
94
+ split="train",
95
+ weights=None,
96
+ ):
97
+ rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
98
+ if self.perceptual_weight > 0:
99
+ p_loss = self.perceptual_loss(
100
+ inputs.contiguous(), reconstructions.contiguous()
101
+ )
102
+ rec_loss = rec_loss + self.perceptual_weight * p_loss
103
+
104
+ nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
105
+ weighted_nll_loss = nll_loss
106
+ if weights is not None:
107
+ weighted_nll_loss = weights * nll_loss
108
+ weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
109
+ nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
110
+ kl_loss = posteriors.kl()
111
+ kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
112
+
113
+ # now the GAN part
114
+ if optimizer_idx == 0:
115
+ # generator update
116
+ if cond is None:
117
+ assert not self.disc_conditional
118
+ logits_fake = self.discriminator(reconstructions.contiguous())
119
+ else:
120
+ assert self.disc_conditional
121
+ logits_fake = self.discriminator(
122
+ torch.cat((reconstructions.contiguous(), cond), dim=1)
123
+ )
124
+ g_loss = -torch.mean(logits_fake)
125
+
126
+ if self.disc_factor > 0.0:
127
+ try:
128
+ d_weight = self.calculate_adaptive_weight(
129
+ nll_loss, g_loss, last_layer=last_layer
130
+ )
131
+ except RuntimeError:
132
+ assert not self.training
133
+ d_weight = torch.tensor(0.0)
134
+ else:
135
+ d_weight = torch.tensor(0.0)
136
+
137
+ disc_factor = adopt_weight(
138
+ self.disc_factor, global_step, threshold=self.discriminator_iter_start
139
+ )
140
+ loss = (
141
+ weighted_nll_loss
142
+ + self.kl_weight * kl_loss
143
+ + d_weight * disc_factor * g_loss
144
+ )
145
+
146
+ log = {
147
+ "{}/total_loss".format(split): loss.clone().detach().mean(),
148
+ "{}/logvar".format(split): self.logvar.detach(),
149
+ "{}/kl_loss".format(split): kl_loss.detach().mean(),
150
+ "{}/nll_loss".format(split): nll_loss.detach().mean(),
151
+ "{}/rec_loss".format(split): rec_loss.detach().mean(),
152
+ "{}/d_weight".format(split): d_weight.detach(),
153
+ "{}/disc_factor".format(split): torch.tensor(disc_factor),
154
+ "{}/g_loss".format(split): g_loss.detach().mean(),
155
+ }
156
+ return loss, log
157
+
158
+ if optimizer_idx == 1:
159
+ # second pass for discriminator update
160
+ if cond is None:
161
+ logits_real = self.discriminator(inputs.contiguous().detach())
162
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
163
+ else:
164
+ logits_real = self.discriminator(
165
+ torch.cat((inputs.contiguous().detach(), cond), dim=1)
166
+ )
167
+ logits_fake = self.discriminator(
168
+ torch.cat((reconstructions.contiguous().detach(), cond), dim=1)
169
+ )
170
+
171
+ disc_factor = adopt_weight(
172
+ self.disc_factor, global_step, threshold=self.discriminator_iter_start
173
+ )
174
+ d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
175
+
176
+ log = {
177
+ "{}/disc_loss".format(split): d_loss.clone().detach().mean(),
178
+ "{}/logits_real".format(split): logits_real.detach().mean(),
179
+ "{}/logits_fake".format(split): logits_fake.detach().mean(),
180
+ }
181
+ return d_loss, log
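A sketch of the two-optimizer interface of LPIPSWithDiscriminator: optimizer_idx 0 returns the autoencoder loss and log dict, optimizer_idx 1 the discriminator loss. The random tensors and disc_factor=0.0 only keep the example self-contained (constructing the loss downloads the LPIPS VGG weights on first use).

import torch
from swim.modules.losses import LPIPSWithDiscriminator
from swim.modules.distributions.distributions import DiagonalGaussianDistribution

loss_fn = LPIPSWithDiscriminator(disc_start=1, disc_factor=0.0, kl_weight=1e-6)

inputs = torch.rand(2, 3, 64, 64)
reconstructions = torch.rand(2, 3, 64, 64)
posterior = DiagonalGaussianDistribution(torch.randn(2, 8, 8, 8))

ae_loss, ae_log = loss_fn(inputs, reconstructions, posterior,
                          optimizer_idx=0, global_step=0, split="train")
d_loss, d_log = loss_fn(inputs, reconstructions, posterior,
                        optimizer_idx=1, global_step=0, split="train")
print(ae_log["train/rec_loss"], d_log["train/disc_loss"])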
swim/modules/x_transformer.py ADDED
@@ -0,0 +1,641 @@
1
+ """shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers"""
2
+ import torch
3
+ from torch import nn, einsum
4
+ import torch.nn.functional as F
5
+ from functools import partial
6
+ from inspect import isfunction
7
+ from collections import namedtuple
8
+ from einops import rearrange, repeat, reduce
9
+
10
+ # constants
11
+
12
+ DEFAULT_DIM_HEAD = 64
13
+
14
+ Intermediates = namedtuple('Intermediates', [
15
+ 'pre_softmax_attn',
16
+ 'post_softmax_attn'
17
+ ])
18
+
19
+ LayerIntermediates = namedtuple('Intermediates', [
20
+ 'hiddens',
21
+ 'attn_intermediates'
22
+ ])
23
+
24
+
25
+ class AbsolutePositionalEmbedding(nn.Module):
26
+ def __init__(self, dim, max_seq_len):
27
+ super().__init__()
28
+ self.emb = nn.Embedding(max_seq_len, dim)
29
+ self.init_()
30
+
31
+ def init_(self):
32
+ nn.init.normal_(self.emb.weight, std=0.02)
33
+
34
+ def forward(self, x):
35
+ n = torch.arange(x.shape[1], device=x.device)
36
+ return self.emb(n)[None, :, :]
37
+
38
+
39
+ class FixedPositionalEmbedding(nn.Module):
40
+ def __init__(self, dim):
41
+ super().__init__()
42
+ inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
43
+ self.register_buffer('inv_freq', inv_freq)
44
+
45
+ def forward(self, x, seq_dim=1, offset=0):
46
+ t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset
47
+ sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq)
48
+ emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
49
+ return emb[None, :, :]
50
+
51
+
52
+ # helpers
53
+
54
+ def exists(val):
55
+ return val is not None
56
+
57
+
58
+ def default(val, d):
59
+ if exists(val):
60
+ return val
61
+ return d() if isfunction(d) else d
62
+
63
+
64
+ def always(val):
65
+ def inner(*args, **kwargs):
66
+ return val
67
+ return inner
68
+
69
+
70
+ def not_equals(val):
71
+ def inner(x):
72
+ return x != val
73
+ return inner
74
+
75
+
76
+ def equals(val):
77
+ def inner(x):
78
+ return x == val
79
+ return inner
80
+
81
+
82
+ def max_neg_value(tensor):
83
+ return -torch.finfo(tensor.dtype).max
84
+
85
+
86
+ # keyword argument helpers
87
+
88
+ def pick_and_pop(keys, d):
89
+ values = list(map(lambda key: d.pop(key), keys))
90
+ return dict(zip(keys, values))
91
+
92
+
93
+ def group_dict_by_key(cond, d):
94
+ return_val = [dict(), dict()]
95
+ for key in d.keys():
96
+ match = bool(cond(key))
97
+ ind = int(not match)
98
+ return_val[ind][key] = d[key]
99
+ return (*return_val,)
100
+
101
+
102
+ def string_begins_with(prefix, str):
103
+ return str.startswith(prefix)
104
+
105
+
106
+ def group_by_key_prefix(prefix, d):
107
+ return group_dict_by_key(partial(string_begins_with, prefix), d)
108
+
109
+
110
+ def groupby_prefix_and_trim(prefix, d):
111
+ kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
112
+ kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
113
+ return kwargs_without_prefix, kwargs
114
+
115
+
116
+ # classes
117
+ class Scale(nn.Module):
118
+ def __init__(self, value, fn):
119
+ super().__init__()
120
+ self.value = value
121
+ self.fn = fn
122
+
123
+ def forward(self, x, **kwargs):
124
+ x, *rest = self.fn(x, **kwargs)
125
+ return (x * self.value, *rest)
126
+
127
+
128
+ class Rezero(nn.Module):
129
+ def __init__(self, fn):
130
+ super().__init__()
131
+ self.fn = fn
132
+ self.g = nn.Parameter(torch.zeros(1))
133
+
134
+ def forward(self, x, **kwargs):
135
+ x, *rest = self.fn(x, **kwargs)
136
+ return (x * self.g, *rest)
137
+
138
+
139
+ class ScaleNorm(nn.Module):
140
+ def __init__(self, dim, eps=1e-5):
141
+ super().__init__()
142
+ self.scale = dim ** -0.5
143
+ self.eps = eps
144
+ self.g = nn.Parameter(torch.ones(1))
145
+
146
+ def forward(self, x):
147
+ norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
148
+ return x / norm.clamp(min=self.eps) * self.g
149
+
150
+
151
+ class RMSNorm(nn.Module):
152
+ def __init__(self, dim, eps=1e-8):
153
+ super().__init__()
154
+ self.scale = dim ** -0.5
155
+ self.eps = eps
156
+ self.g = nn.Parameter(torch.ones(dim))
157
+
158
+ def forward(self, x):
159
+ norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
160
+ return x / norm.clamp(min=self.eps) * self.g
161
+
162
+
163
+ class Residual(nn.Module):
164
+ def forward(self, x, residual):
165
+ return x + residual
166
+
167
+
168
+ class GRUGating(nn.Module):
169
+ def __init__(self, dim):
170
+ super().__init__()
171
+ self.gru = nn.GRUCell(dim, dim)
172
+
173
+ def forward(self, x, residual):
174
+ gated_output = self.gru(
175
+ rearrange(x, 'b n d -> (b n) d'),
176
+ rearrange(residual, 'b n d -> (b n) d')
177
+ )
178
+
179
+ return gated_output.reshape_as(x)
180
+
181
+
182
+ # feedforward
183
+
184
+ class GEGLU(nn.Module):
185
+ def __init__(self, dim_in, dim_out):
186
+ super().__init__()
187
+ self.proj = nn.Linear(dim_in, dim_out * 2)
188
+
189
+ def forward(self, x):
190
+ x, gate = self.proj(x).chunk(2, dim=-1)
191
+ return x * F.gelu(gate)
192
+
193
+
194
+ class FeedForward(nn.Module):
195
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
196
+ super().__init__()
197
+ inner_dim = int(dim * mult)
198
+ dim_out = default(dim_out, dim)
199
+ project_in = nn.Sequential(
200
+ nn.Linear(dim, inner_dim),
201
+ nn.GELU()
202
+ ) if not glu else GEGLU(dim, inner_dim)
203
+
204
+ self.net = nn.Sequential(
205
+ project_in,
206
+ nn.Dropout(dropout),
207
+ nn.Linear(inner_dim, dim_out)
208
+ )
209
+
210
+ def forward(self, x):
211
+ return self.net(x)
212
+
213
+
214
+ # attention.
215
+ class Attention(nn.Module):
216
+ def __init__(
217
+ self,
218
+ dim,
219
+ dim_head=DEFAULT_DIM_HEAD,
220
+ heads=8,
221
+ causal=False,
222
+ mask=None,
223
+ talking_heads=False,
224
+ sparse_topk=None,
225
+ use_entmax15=False,
226
+ num_mem_kv=0,
227
+ dropout=0.,
228
+ on_attn=False
229
+ ):
230
+ super().__init__()
231
+ if use_entmax15:
232
+ raise NotImplementedError("Check out entmax activation instead of softmax activation!")
233
+ self.scale = dim_head ** -0.5
234
+ self.heads = heads
235
+ self.causal = causal
236
+ self.mask = mask
237
+
238
+ inner_dim = dim_head * heads
239
+
240
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
241
+ self.to_k = nn.Linear(dim, inner_dim, bias=False)
242
+ self.to_v = nn.Linear(dim, inner_dim, bias=False)
243
+ self.dropout = nn.Dropout(dropout)
244
+
245
+ # talking heads
246
+ self.talking_heads = talking_heads
247
+ if talking_heads:
248
+ self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads))
249
+ self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads))
250
+
251
+ # explicit topk sparse attention
252
+ self.sparse_topk = sparse_topk
253
+
254
+ # entmax
255
+ #self.attn_fn = entmax15 if use_entmax15 else F.softmax
256
+ self.attn_fn = F.softmax
257
+
258
+ # add memory key / values
259
+ self.num_mem_kv = num_mem_kv
260
+ if num_mem_kv > 0:
261
+ self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
262
+ self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
263
+
264
+ # attention on attention
265
+ self.attn_on_attn = on_attn
266
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim)
267
+
268
+ def forward(
269
+ self,
270
+ x,
271
+ context=None,
272
+ mask=None,
273
+ context_mask=None,
274
+ rel_pos=None,
275
+ sinusoidal_emb=None,
276
+ prev_attn=None,
277
+ mem=None
278
+ ):
279
+ b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device
280
+ kv_input = default(context, x)
281
+
282
+ q_input = x
283
+ k_input = kv_input
284
+ v_input = kv_input
285
+
286
+ if exists(mem):
287
+ k_input = torch.cat((mem, k_input), dim=-2)
288
+ v_input = torch.cat((mem, v_input), dim=-2)
289
+
290
+ if exists(sinusoidal_emb):
291
+ # in shortformer, the query would start at a position offset depending on the past cached memory
292
+ offset = k_input.shape[-2] - q_input.shape[-2]
293
+ q_input = q_input + sinusoidal_emb(q_input, offset=offset)
294
+ k_input = k_input + sinusoidal_emb(k_input)
295
+
296
+ q = self.to_q(q_input)
297
+ k = self.to_k(k_input)
298
+ v = self.to_v(v_input)
299
+
300
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
301
+
302
+ input_mask = None
303
+ if any(map(exists, (mask, context_mask))):
304
+ q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())
305
+ k_mask = q_mask if not exists(context) else context_mask
306
+ k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool())
307
+ q_mask = rearrange(q_mask, 'b i -> b () i ()')
308
+ k_mask = rearrange(k_mask, 'b j -> b () () j')
309
+ input_mask = q_mask * k_mask
310
+
311
+ if self.num_mem_kv > 0:
312
+ mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v))
313
+ k = torch.cat((mem_k, k), dim=-2)
314
+ v = torch.cat((mem_v, v), dim=-2)
315
+ if exists(input_mask):
316
+ input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)
317
+
318
+ dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
319
+ mask_value = max_neg_value(dots)
320
+
321
+ if exists(prev_attn):
322
+ dots = dots + prev_attn
323
+
324
+ pre_softmax_attn = dots
325
+
326
+ if talking_heads:
327
+ dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous()
328
+
329
+ if exists(rel_pos):
330
+ dots = rel_pos(dots)
331
+
332
+ if exists(input_mask):
333
+ dots.masked_fill_(~input_mask, mask_value)
334
+ del input_mask
335
+
336
+ if self.causal:
337
+ i, j = dots.shape[-2:]
338
+ r = torch.arange(i, device=device)
339
+ mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j')
340
+ mask = F.pad(mask, (j - i, 0), value=False)
341
+ dots.masked_fill_(mask, mask_value)
342
+ del mask
343
+
344
+ if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
345
+ top, _ = dots.topk(self.sparse_topk, dim=-1)
346
+ vk = top[..., -1].unsqueeze(-1).expand_as(dots)
347
+ mask = dots < vk
348
+ dots.masked_fill_(mask, mask_value)
349
+ del mask
350
+
351
+ attn = self.attn_fn(dots, dim=-1)
352
+ post_softmax_attn = attn
353
+
354
+ attn = self.dropout(attn)
355
+
356
+ if talking_heads:
357
+ attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous()
358
+
359
+ out = einsum('b h i j, b h j d -> b h i d', attn, v)
360
+ out = rearrange(out, 'b h n d -> b n (h d)')
361
+
362
+ intermediates = Intermediates(
363
+ pre_softmax_attn=pre_softmax_attn,
364
+ post_softmax_attn=post_softmax_attn
365
+ )
366
+
367
+ return self.to_out(out), intermediates
368
+
369
+
370
+ class AttentionLayers(nn.Module):
371
+ def __init__(
372
+ self,
373
+ dim,
374
+ depth,
375
+ heads=8,
376
+ causal=False,
377
+ cross_attend=False,
378
+ only_cross=False,
379
+ use_scalenorm=False,
380
+ use_rmsnorm=False,
381
+ use_rezero=False,
382
+ rel_pos_num_buckets=32,
383
+ rel_pos_max_distance=128,
384
+ position_infused_attn=False,
385
+ custom_layers=None,
386
+ sandwich_coef=None,
387
+ par_ratio=None,
388
+ residual_attn=False,
389
+ cross_residual_attn=False,
390
+ macaron=False,
391
+ pre_norm=True,
392
+ gate_residual=False,
393
+ **kwargs
394
+ ):
395
+ super().__init__()
396
+ ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
397
+ attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs)
398
+
399
+ dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
400
+
401
+ self.dim = dim
402
+ self.depth = depth
403
+ self.layers = nn.ModuleList([])
404
+
405
+ self.has_pos_emb = position_infused_attn
406
+ self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None
407
+ self.rotary_pos_emb = always(None)
408
+
409
+ assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must not exceed the relative position max distance'
410
+ self.rel_pos = None
411
+
412
+ self.pre_norm = pre_norm
413
+
414
+ self.residual_attn = residual_attn
415
+ self.cross_residual_attn = cross_residual_attn
416
+
417
+ norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
418
+ norm_class = RMSNorm if use_rmsnorm else norm_class
419
+ norm_fn = partial(norm_class, dim)
420
+
421
+ norm_fn = nn.Identity if use_rezero else norm_fn
422
+ branch_fn = Rezero if use_rezero else None
423
+
424
+ if cross_attend and not only_cross:
425
+ default_block = ('a', 'c', 'f')
426
+ elif cross_attend and only_cross:
427
+ default_block = ('c', 'f')
428
+ else:
429
+ default_block = ('a', 'f')
430
+
431
+ if macaron:
432
+ default_block = ('f',) + default_block
433
+
434
+ if exists(custom_layers):
435
+ layer_types = custom_layers
436
+ elif exists(par_ratio):
437
+ par_depth = depth * len(default_block)
438
+ assert 1 < par_ratio <= par_depth, 'par ratio out of range'
439
+ default_block = tuple(filter(not_equals('f'), default_block))
440
+ par_attn = par_depth // par_ratio
441
+ depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper
442
+ par_width = (depth_cut + depth_cut // par_attn) // par_attn
443
+ assert len(default_block) <= par_width, 'default block is too large for par_ratio'
444
+ par_block = default_block + ('f',) * (par_width - len(default_block))
445
+ par_head = par_block * par_attn
446
+ layer_types = par_head + ('f',) * (par_depth - len(par_head))
447
+ elif exists(sandwich_coef):
448
+ assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient must be positive and no greater than the depth'
449
+ layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
450
+ else:
451
+ layer_types = default_block * depth
452
+
453
+ self.layer_types = layer_types
454
+ self.num_attn_layers = len(list(filter(equals('a'), layer_types)))
455
+
456
+ for layer_type in self.layer_types:
457
+ if layer_type == 'a':
458
+ layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
459
+ elif layer_type == 'c':
460
+ layer = Attention(dim, heads=heads, **attn_kwargs)
461
+ elif layer_type == 'f':
462
+ layer = FeedForward(dim, **ff_kwargs)
463
+ layer = layer if not macaron else Scale(0.5, layer)
464
+ else:
465
+ raise Exception(f'invalid layer type {layer_type}')
466
+
467
+ if isinstance(layer, Attention) and exists(branch_fn):
468
+ layer = branch_fn(layer)
469
+
470
+ if gate_residual:
471
+ residual_fn = GRUGating(dim)
472
+ else:
473
+ residual_fn = Residual()
474
+
475
+ self.layers.append(nn.ModuleList([
476
+ norm_fn(),
477
+ layer,
478
+ residual_fn
479
+ ]))
480
+
481
+ def forward(
482
+ self,
483
+ x,
484
+ context=None,
485
+ mask=None,
486
+ context_mask=None,
487
+ mems=None,
488
+ return_hiddens=False
489
+ ):
490
+ hiddens = []
491
+ intermediates = []
492
+ prev_attn = None
493
+ prev_cross_attn = None
494
+
495
+ mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
496
+
497
+ for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
498
+ is_last = ind == (len(self.layers) - 1)
499
+
500
+ if layer_type == 'a':
501
+ hiddens.append(x)
502
+ layer_mem = mems.pop(0)
503
+
504
+ residual = x
505
+
506
+ if self.pre_norm:
507
+ x = norm(x)
508
+
509
+ if layer_type == 'a':
510
+ out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos,
511
+ prev_attn=prev_attn, mem=layer_mem)
512
+ elif layer_type == 'c':
513
+ out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn)
514
+ elif layer_type == 'f':
515
+ out = block(x)
516
+
517
+ x = residual_fn(out, residual)
518
+
519
+ if layer_type in ('a', 'c'):
520
+ intermediates.append(inter)
521
+
522
+ if layer_type == 'a' and self.residual_attn:
523
+ prev_attn = inter.pre_softmax_attn
524
+ elif layer_type == 'c' and self.cross_residual_attn:
525
+ prev_cross_attn = inter.pre_softmax_attn
526
+
527
+ if not self.pre_norm and not is_last:
528
+ x = norm(x)
529
+
530
+ if return_hiddens:
531
+ intermediates = LayerIntermediates(
532
+ hiddens=hiddens,
533
+ attn_intermediates=intermediates
534
+ )
535
+
536
+ return x, intermediates
537
+
538
+ return x
539
+
540
+
541
+ class Encoder(AttentionLayers):
542
+ def __init__(self, **kwargs):
543
+ assert 'causal' not in kwargs, 'cannot set causality on encoder'
544
+ super().__init__(causal=False, **kwargs)
545
+
546
+
547
+
548
+ class TransformerWrapper(nn.Module):
549
+ def __init__(
550
+ self,
551
+ *,
552
+ num_tokens,
553
+ max_seq_len,
554
+ attn_layers,
555
+ emb_dim=None,
556
+ max_mem_len=0.,
557
+ emb_dropout=0.,
558
+ num_memory_tokens=None,
559
+ tie_embedding=False,
560
+ use_pos_emb=True
561
+ ):
562
+ super().__init__()
563
+ assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
564
+
565
+ dim = attn_layers.dim
566
+ emb_dim = default(emb_dim, dim)
567
+
568
+ self.max_seq_len = max_seq_len
569
+ self.max_mem_len = max_mem_len
570
+ self.num_tokens = num_tokens
571
+
572
+ self.token_emb = nn.Embedding(num_tokens, emb_dim)
573
+ self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
574
+ use_pos_emb and not attn_layers.has_pos_emb) else always(0)
575
+ self.emb_dropout = nn.Dropout(emb_dropout)
576
+
577
+ self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
578
+ self.attn_layers = attn_layers
579
+ self.norm = nn.LayerNorm(dim)
580
+
581
+ self.init_()
582
+
583
+ self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
584
+
585
+ # memory tokens (like [cls]) from Memory Transformers paper
586
+ num_memory_tokens = default(num_memory_tokens, 0)
587
+ self.num_memory_tokens = num_memory_tokens
588
+ if num_memory_tokens > 0:
589
+ self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
590
+
591
+ # let funnel encoder know number of memory tokens, if specified
592
+ if hasattr(attn_layers, 'num_memory_tokens'):
593
+ attn_layers.num_memory_tokens = num_memory_tokens
594
+
595
+ def init_(self):
596
+ nn.init.normal_(self.token_emb.weight, std=0.02)
597
+
598
+ def forward(
599
+ self,
600
+ x,
601
+ return_embeddings=False,
602
+ mask=None,
603
+ return_mems=False,
604
+ return_attn=False,
605
+ mems=None,
606
+ **kwargs
607
+ ):
608
+ b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
609
+ x = self.token_emb(x)
610
+ x += self.pos_emb(x)
611
+ x = self.emb_dropout(x)
612
+
613
+ x = self.project_emb(x)
614
+
615
+ if num_mem > 0:
616
+ mem = repeat(self.memory_tokens, 'n d -> b n d', b=b)
617
+ x = torch.cat((mem, x), dim=1)
618
+
619
+ # auto-handle masking after appending memory tokens
620
+ if exists(mask):
621
+ mask = F.pad(mask, (num_mem, 0), value=True)
622
+
623
+ x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
624
+ x = self.norm(x)
625
+
626
+ mem, x = x[:, :num_mem], x[:, num_mem:]
627
+
628
+ out = self.to_logits(x) if not return_embeddings else x
629
+
630
+ if return_mems:
631
+ hiddens = intermediates.hiddens
632
+ new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens
633
+ new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems))
634
+ return out, new_mems
635
+
636
+ if return_attn:
637
+ attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
638
+ return out, attn_maps
639
+
640
+ return out
641
+
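A sketch of the encoder-style usage of TransformerWrapper above (token ids in, per-token embeddings out); vocabulary size, sequence length and model width are toy values.

import torch
from swim.modules.x_transformer import Encoder, TransformerWrapper

model = TransformerWrapper(
    num_tokens=1000,
    max_seq_len=77,
    attn_layers=Encoder(dim=128, depth=2, heads=4),
)
tokens = torch.randint(0, 1000, (2, 77))
z = model(tokens, return_embeddings=True)   # [2, 77, 128]
print(z.shape)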
swim/unet.py DELETED
@@ -1,169 +0,0 @@
1
- import math
2
- from typing import List
3
-
4
- import torch
5
- import torch.nn as nn
6
- import torch.nn.functional as F
7
-
8
- from .attention_blocks import SpatialTransformer
9
- from .blocks import (
10
- DownSample,
11
- ResnetBlock,
12
- TimestepEmbedSequential,
13
- UpSample,
14
- Normalization,
15
- get_timestep_embedding,
16
- )
17
-
18
-
19
- class UNet(nn.Module):
20
- """
21
- ## U-Net model
22
- """
23
-
24
- def __init__(
25
- self,
26
- *,
27
- in_channels: int,
28
- out_channels: int,
29
- channels: int,
30
- n_res_blocks: int,
31
- attention_levels: List[int],
32
- channel_multipliers: List[int],
33
- n_heads: int,
34
- tf_layers: int = 1,
35
- d_cond: int = 768
36
- ):
37
- """
38
- :param in_channels: is the number of channels in the input feature map
39
- :param out_channels: is the number of channels in the output feature map
40
- :param channels: is the base channel count for the model
41
- :param n_res_blocks: number of residual blocks at each level
42
- :param attention_levels: are the levels at which attention should be performed
43
- :param channel_multipliers: are the multiplicative factors for number of channels for each level
44
- :param n_heads: is the number of attention heads in the transformers
45
- :param tf_layers: is the number of transformer layers in the transformers
46
- :param d_cond: is the size of the conditional embedding in the transformers
47
- """
48
- super().__init__()
49
- self.channels = channels
50
-
51
- # Number of levels
52
- levels = len(channel_multipliers)
53
- # Size time embeddings
54
- d_time_emb = channels * 4
55
- self.time_embed = nn.Sequential(
56
- nn.Linear(channels, d_time_emb),
57
- nn.SiLU(),
58
- nn.Linear(d_time_emb, d_time_emb),
59
- )
60
-
61
- # Input half of the U-Net
62
- self.input_blocks = nn.ModuleList()
63
- # Initial $3 \times 3$ convolution that maps the input to `channels`.
64
- # The blocks are wrapped in `TimestepEmbedSequential` module because
65
- # different modules have different forward function signatures;
66
- # for example, convolution only accepts the feature map and
67
- # residual blocks accept the feature map and time embedding.
68
- # `TimestepEmbedSequential` calls them accordingly.
69
- self.input_blocks.append(
70
- TimestepEmbedSequential(nn.Conv2d(in_channels, channels, 3, padding=1))
71
- )
72
- # Number of channels at each block in the input half of U-Net
73
- input_block_channels = [channels]
74
- # Number of channels at each level
75
- channels_list = [channels * m for m in channel_multipliers]
76
- # Prepare levels
77
- for i in range(levels):
78
- # Add the residual blocks and attentions
79
- for _ in range(n_res_blocks):
80
- # Residual block maps from previous number of channels to the number of
81
- # channels in the current level
82
- layers = [
83
- ResnetBlock(channels, d_time_emb, out_channels=channels_list[i])
84
- ]
85
- channels = channels_list[i]
86
- # Add transformer
87
- if i in attention_levels:
88
- layers.append(
89
- SpatialTransformer(channels, n_heads, tf_layers, d_cond)
90
- )
91
- # Add them to the input half of the U-Net and keep track of the number of channels of
92
- # its output
93
- self.input_blocks.append(TimestepEmbedSequential(*layers))
94
- input_block_channels.append(channels)
95
- # Down sample at all levels except last
96
- if i != levels - 1:
97
- self.input_blocks.append(TimestepEmbedSequential(DownSample(channels)))
98
- input_block_channels.append(channels)
99
-
100
- # The middle of the U-Net
101
- self.middle_block = TimestepEmbedSequential(
102
- ResnetBlock(channels, d_time_emb),
103
- SpatialTransformer(channels, n_heads, tf_layers, d_cond),
104
- ResnetBlock(channels, d_time_emb),
105
- )
106
-
107
- # Second half of the U-Net
108
- self.output_blocks = nn.ModuleList([])
109
- # Prepare levels in reverse order
110
- for i in reversed(range(levels)):
111
- # Add the residual blocks and attentions
112
- for j in range(n_res_blocks + 1):
113
- # Residual block maps from previous number of channels plus the
114
- # skip connections from the input half of U-Net to the number of
115
- # channels in the current level.
116
- layers = [
117
- ResnetBlock(
118
- channels + input_block_channels.pop(),
119
- d_time_emb,
120
- out_channels=channels_list[i],
121
- )
122
- ]
123
- channels = channels_list[i]
124
- # Add transformer
125
- if i in attention_levels:
126
- layers.append(
127
- SpatialTransformer(channels, n_heads, tf_layers, d_cond)
128
- )
129
- # Up-sample at every level after last residual block
130
- # except the last one.
131
- # Note that we are iterating in reverse; i.e. `i == 0` is the last.
132
- if i != 0 and j == n_res_blocks:
133
- layers.append(UpSample(channels))
134
- # Add to the output half of the U-Net
135
- self.output_blocks.append(TimestepEmbedSequential(*layers))
136
-
137
- # Final normalization and $3 \times 3$ convolution
138
- self.out = nn.Sequential(
139
- Normalization(channels),
140
- nn.SiLU(),
141
- nn.Conv2d(channels, out_channels, 3, padding=1),
142
- )
143
-
144
- def forward(self, x: torch.Tensor, timesteps: torch.Tensor, cond: torch.Tensor):
145
- """
146
- :param x: is the input feature map of shape `[batch_size, channels, width, height]`
147
- :param timesteps: are the time steps of shape `[batch_size]`
148
- :param cond: conditioning of shape `[batch_size, n_cond, d_cond]`
149
- """
150
- # To store the input half outputs for skip connections
151
- x_input_block = []
152
-
153
- # Get time step embeddings
154
- t_emb = get_timestep_embedding(timesteps, self.channels * 2)
155
- t_emb = self.time_embed(t_emb)
156
-
157
- # Input half of the U-Net
158
- for module in self.input_blocks:
159
- x = module(x, t_emb, cond)
160
- x_input_block.append(x)
161
- # Middle of the U-Net
162
- x = self.middle_block(x, t_emb, cond)
163
- # Output half of the U-Net
164
- for module in self.output_blocks:
165
- x = torch.cat([x, x_input_block.pop()], dim=1)
166
- x = module(x, t_emb, cond)
167
-
168
- # Final normalization and $3 \times 3$ convolution
169
- return self.out(x)
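The deleted U-Net above records the channel count after every encoder block so the decoder knows how wide each concatenated skip connection will be. A standalone sketch of that bookkeeping, using illustrative numbers rather than values taken from any config in this repo:

channels, n_res_blocks = 128, 2
channel_multipliers = [1, 2, 4, 4]
channels_list = [channels * m for m in channel_multipliers]  # [128, 256, 512, 512]

input_block_channels = [channels]            # after the initial 3x3 conv
for i, ch in enumerate(channels_list):
    for _ in range(n_res_blocks):            # one entry per residual block
        input_block_channels.append(ch)
    if i != len(channels_list) - 1:          # downsample keeps the width
        input_block_channels.append(ch)

print(input_block_channels)
# [128, 128, 128, 128, 256, 256, 256, 512, 512, 512, 512, 512]
# The decoder pops from this list, so each output-side ResnetBlock receives
# channels + input_block_channels.pop() input channels.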
swim/utils.py ADDED
@@ -0,0 +1,297 @@
1
+ import importlib
2
+
3
+ import torch
4
+ from torch import nn
5
+ import numpy as np
6
+ from collections import abc
7
+ from einops import rearrange
8
+ from functools import partial
9
+
10
+ import multiprocessing as mp
11
+ from threading import Thread
12
+ from queue import Queue
13
+
14
+ from inspect import isfunction
15
+ from PIL import Image, ImageDraw, ImageFont
16
+
17
+
18
+ def log_txt_as_img(wh, xc, size=10):
19
+ # wh a tuple of (width, height)
20
+ # xc a list of captions to plot
21
+ b = len(xc)
22
+ txts = list()
23
+ for bi in range(b):
24
+ txt = Image.new("RGB", wh, color="white")
25
+ draw = ImageDraw.Draw(txt)
26
+ font = ImageFont.truetype("data/DejaVuSans.ttf", size=size)
27
+ nc = int(40 * (wh[0] / 256))
28
+ lines = "\n".join(
29
+ xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc)
30
+ )
31
+
32
+ try:
33
+ draw.text((0, 0), lines, fill="black", font=font)
34
+ except UnicodeEncodeError:
35
+ print("Cant encode string for logging. Skipping.")
36
+
37
+ txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
38
+ txts.append(txt)
39
+ txts = np.stack(txts)
40
+ txts = torch.tensor(txts)
41
+ return txts
42
+
43
+
44
+ def ismap(x):
45
+ if not isinstance(x, torch.Tensor):
46
+ return False
47
+ return (len(x.shape) == 4) and (x.shape[1] > 3)
48
+
49
+
50
+ def isimage(x):
51
+ if not isinstance(x, torch.Tensor):
52
+ return False
53
+ return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
54
+
55
+
56
+ def exists(x):
57
+ return x is not None
58
+
59
+
60
+ def default(val, d):
61
+ if exists(val):
62
+ return val
63
+ return d() if isfunction(d) else d
64
+
65
+
66
+ def mean_flat(tensor):
67
+ """
68
+ https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
69
+ Take the mean over all non-batch dimensions.
70
+ """
71
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
72
+
73
+
74
+ def count_params(model, verbose=False):
75
+ total_params = sum(p.numel() for p in model.parameters())
76
+ if verbose:
77
+ print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
78
+ return total_params
79
+
80
+
81
+ def instantiate_from_config(config):
82
+ if not "target" in config:
83
+ if config == "__is_first_stage__":
84
+ return None
85
+ elif config == "__is_unconditional__":
86
+ return None
87
+ raise KeyError("Expected key `target` to instantiate.")
88
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
89
+
90
+
91
+ def get_obj_from_str(string, reload=False):
92
+ module, cls = string.rsplit(".", 1)
93
+ if reload:
94
+ module_imp = importlib.import_module(module)
95
+ importlib.reload(module_imp)
96
+ return getattr(importlib.import_module(module, package=None), cls)
97
+
98
+
99
+ def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
100
+ # create dummy dataset instance
101
+
102
+ # run prefetching
103
+ if idx_to_fn:
104
+ res = func(data, worker_id=idx)
105
+ else:
106
+ res = func(data)
107
+ Q.put([idx, res])
108
+ Q.put("Done")
109
+
110
+
111
+ def parallel_data_prefetch(
112
+ func: callable,
113
+ data,
114
+ n_proc,
115
+ target_data_type="ndarray",
116
+ cpu_intensive=True,
117
+ use_worker_id=False,
118
+ ):
119
+ # if target_data_type not in ["ndarray", "list"]:
120
+ # raise ValueError(
121
+ # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
122
+ # )
123
+ if isinstance(data, np.ndarray) and target_data_type == "list":
124
+ raise ValueError("list expected but function got ndarray.")
125
+ elif isinstance(data, abc.Iterable):
126
+ if isinstance(data, dict):
127
+ print(
128
+ f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
129
+ )
130
+ data = list(data.values())
131
+ if target_data_type == "ndarray":
132
+ data = np.asarray(data)
133
+ else:
134
+ data = list(data)
135
+ else:
136
+ raise TypeError(
137
+ f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
138
+ )
139
+
140
+ if cpu_intensive:
141
+ Q = mp.Queue(1000)
142
+ proc = mp.Process
143
+ else:
144
+ Q = Queue(1000)
145
+ proc = Thread
146
+ # spawn processes
147
+ if target_data_type == "ndarray":
148
+ arguments = [
149
+ [func, Q, part, i, use_worker_id]
150
+ for i, part in enumerate(np.array_split(data, n_proc))
151
+ ]
152
+ else:
153
+ step = (
154
+ int(len(data) / n_proc + 1)
155
+ if len(data) % n_proc != 0
156
+ else int(len(data) / n_proc)
157
+ )
158
+ arguments = [
159
+ [func, Q, part, i, use_worker_id]
160
+ for i, part in enumerate(
161
+ [data[i : i + step] for i in range(0, len(data), step)]
162
+ )
163
+ ]
164
+ processes = []
165
+ for i in range(n_proc):
166
+ p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
167
+ processes += [p]
168
+
169
+ # start processes
170
+ print(f"Start prefetching...")
171
+ import time
172
+
173
+ start = time.time()
174
+ gather_res = [[] for _ in range(n_proc)]
175
+ try:
176
+ for p in processes:
177
+ p.start()
178
+
179
+ k = 0
180
+ while k < n_proc:
181
+ # get result
182
+ res = Q.get()
183
+ if res == "Done":
184
+ k += 1
185
+ else:
186
+ gather_res[res[0]] = res[1]
187
+
188
+ except Exception as e:
189
+ print("Exception: ", e)
190
+ for p in processes:
191
+ p.terminate()
192
+
193
+ raise e
194
+ finally:
195
+ for p in processes:
196
+ p.join()
197
+ print(f"Prefetching complete. [{time.time() - start} sec.]")
198
+
199
+ if target_data_type == "ndarray":
200
+ if not isinstance(gather_res[0], np.ndarray):
201
+ return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
202
+
203
+ # order outputs
204
+ return np.concatenate(gather_res, axis=0)
205
+ elif target_data_type == "list":
206
+ out = []
207
+ for r in gather_res:
208
+ out.extend(r)
209
+ return out
210
+ else:
211
+ return gather_res
212
+
213
+
214
+ class ActNorm(nn.Module):
215
+ def __init__(
216
+ self, num_features, logdet=False, affine=True, allow_reverse_init=False
217
+ ):
218
+ assert affine
219
+ super().__init__()
220
+ self.logdet = logdet
221
+ self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
222
+ self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
223
+ self.allow_reverse_init = allow_reverse_init
224
+
225
+ self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))
226
+
227
+ def initialize(self, input):
228
+ with torch.no_grad():
229
+ flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
230
+ mean = (
231
+ flatten.mean(1)
232
+ .unsqueeze(1)
233
+ .unsqueeze(2)
234
+ .unsqueeze(3)
235
+ .permute(1, 0, 2, 3)
236
+ )
237
+ std = (
238
+ flatten.std(1)
239
+ .unsqueeze(1)
240
+ .unsqueeze(2)
241
+ .unsqueeze(3)
242
+ .permute(1, 0, 2, 3)
243
+ )
244
+
245
+ self.loc.data.copy_(-mean)
246
+ self.scale.data.copy_(1 / (std + 1e-6))
247
+
248
+ def forward(self, input, reverse=False):
249
+ if reverse:
250
+ return self.reverse(input)
251
+ if len(input.shape) == 2:
252
+ input = input[:, :, None, None]
253
+ squeeze = True
254
+ else:
255
+ squeeze = False
256
+
257
+ _, _, height, width = input.shape
258
+
259
+ if self.training and self.initialized.item() == 0:
260
+ self.initialize(input)
261
+ self.initialized.fill_(1)
262
+
263
+ h = self.scale * (input + self.loc)
264
+
265
+ if squeeze:
266
+ h = h.squeeze(-1).squeeze(-1)
267
+
268
+ if self.logdet:
269
+ log_abs = torch.log(torch.abs(self.scale))
270
+ logdet = height * width * torch.sum(log_abs)
271
+ logdet = logdet * torch.ones(input.shape[0]).to(input)
272
+ return h, logdet
273
+
274
+ return h
275
+
276
+ def reverse(self, output):
277
+ if self.training and self.initialized.item() == 0:
278
+ if not self.allow_reverse_init:
279
+ raise RuntimeError(
280
+ "Initializing ActNorm in reverse direction is "
281
+ "disabled by default. Use allow_reverse_init=True to enable."
282
+ )
283
+ else:
284
+ self.initialize(output)
285
+ self.initialized.fill_(1)
286
+
287
+ if len(output.shape) == 2:
288
+ output = output[:, :, None, None]
289
+ squeeze = True
290
+ else:
291
+ squeeze = False
292
+
293
+ h = output / self.scale - self.loc
294
+
295
+ if squeeze:
296
+ h = h.squeeze(-1).squeeze(-1)
297
+ return h
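For reference, a hypothetical use of the new instantiate_from_config / get_obj_from_str helpers, assuming the swim package is importable; the target path and params below are invented for illustration. The config names a dotted import path plus keyword arguments, and the helper imports that object and constructs it.

from swim.utils import instantiate_from_config, get_obj_from_str

config = {
    "target": "torch.nn.Linear",              # any importable dotted path
    "params": {"in_features": 8, "out_features": 4},
}
layer = instantiate_from_config(config)       # equivalent to torch.nn.Linear(8, 4)

cls = get_obj_from_str("torch.nn.Conv2d")     # import without instantiating
conv = cls(3, 16, kernel_size=3, padding=1)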
train.py DELETED
@@ -1,72 +0,0 @@
1
- import torch
2
- from torchinfo import summary
3
-
4
- from swim.autoencoder import Autoencoder
5
- from diffusers import AutoencoderKL, UNet2DModel
6
-
7
- # vae = Autoencoder(
8
- # z_channels=4,
9
- # in_channels=3,
10
- # channels=128,
11
- # channel_multipliers=[1, 2, 4, 4],
12
- # n_resnet_blocks=2,
13
- # emb_channels=4,
14
- # ).to("meta")
15
- # lol_vae = AutoencoderKL.from_pretrained(
16
- # "stabilityai/stable-diffusion-2-1", subfolder="vae"
17
- # ).to("meta")
18
-
19
- # # copy weights from lol_vae to vae
20
- # import json
21
-
22
- # with open("lolvae.json", "w") as f:
23
- # json.dump(list(lol_vae.state_dict().keys()), f, indent=4)
24
-
25
- # with open("vae.json", "w") as f:
26
- # json.dump(list(vae.state_dict().keys()), f, indent=4)
27
-
28
- # sample = torch.randn(1, 3, 512, 512).to("meta")
29
- # # lantent = vae.encoder(sample)
30
-
31
- from diffusers import UNet2DModel
32
-
33
- model = UNet2DModel(
34
- sample_size=512, # the target image resolution
35
- in_channels=3, # the number of input channels, 3 for RGB images
36
- out_channels=3, # the number of output channels
37
- layers_per_block=2, # how many ResNet layers to use per UNet block
38
- block_out_channels=(
39
- 128,
40
- 128,
41
- 256,
42
- 256,
43
- 512,
44
- 512,
45
- ), # the number of output channels for each UNet block
46
- down_block_types=(
47
- "DownBlock2D", # a regular ResNet downsampling block
48
- "DownBlock2D",
49
- "DownBlock2D",
50
- "DownBlock2D",
51
- "AttnDownBlock2D", # a ResNet downsampling block with spatial self-attention
52
- "DownBlock2D",
53
- ),
54
- up_block_types=(
55
- "UpBlock2D", # a regular ResNet upsampling block
56
- "AttnUpBlock2D", # a ResNet upsampling block with spatial self-attention
57
- "UpBlock2D",
58
- "UpBlock2D",
59
- "UpBlock2D",
60
- "UpBlock2D",
61
- ),
62
- ).to("meta")
63
-
64
- sample = torch.randn(1, 3, 512, 512).to("meta")
65
-
66
- summary(
67
- model,
68
- input_data=(
69
- sample,
70
- 0,
71
- ),
72
- )
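A possible follow-up to the removed script, shown here only as a hypothetical sketch and not part of this commit: the parameter count that torchinfo's summary reported can also be obtained with the count_params helper added in swim/utils.py, still without allocating real weights.

from diffusers import UNet2DModel
from swim.utils import count_params

# Default diffusers U-Net, kept on the meta device so no memory is allocated.
model = UNet2DModel(sample_size=512, in_channels=3, out_channels=3).to("meta")
count_params(model, verbose=True)  # prints "UNet2DModel has ... M params."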