Anton Forsman committed on
Commit
f04c9cc
1 Parent(s): 43fd0ed

put in everything

Files changed (3)
  1. app.py +23 -2
  2. model.py +658 -0
  3. requirements.txt +4 -1
app.py CHANGED
@@ -1,4 +1,25 @@
  import streamlit as st
+ from PIL import Image
+ from inference import inference
+ import io

- x = st.slider('Select a values')
- st.write(x, 'squared is', x * x)
+ def main():
+     st.title("Image Display App")
+
+     # Button to trigger image generation
+     if st.button('Generate Image'):
+         # Call the function from inference.py
+         image = inference()
+
+         # Convert Pillow image to bytes for display in Streamlit
+         img_buffer = io.BytesIO()
+         image.save(img_buffer, format="PNG")
+         img_buffer.seek(0)
+
+         # Display the image
+         st.image(img_buffer, caption='Generated Image', use_column_width=True)
+
+ if __name__ == "__main__":
+     main()
+
+

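app.py now imports inference from an inference.py module that is not part of this commit. A minimal sketch of what such a module might look like, assuming it wraps the Unet, GaussianDiffusion, and DiffusionImageAPI classes added in model.py below; the checkpoint name, schedule values, and image size here are placeholders, not taken from the repository:

# inference.py -- hypothetical sketch, not part of this commit
import torch
from model import Unet, GaussianDiffusion, DiffusionImageAPI

def inference():
    # Build the denoiser and the diffusion wrapper (hyperparameters are assumptions)
    model = Unet(image_channels=3)
    # A trained checkpoint would normally be loaded here, e.g.:
    # model.load_state_dict(torch.load("checkpoint.pt", map_location="cpu"))
    diffusion = GaussianDiffusion(
        model=model,
        noise_steps=1000,
        beta_0=1e-4,
        beta_T=0.02,
        image_size=(64, 64),
    )
    # sample() returns denormalized pixels of shape (num_samples, channels, H, W)
    samples, _ = diffusion.sample(num_samples=1, show_progress=False)
    api = DiffusionImageAPI(diffusion)
    # Rearrange to (H, W, C) so PIL can build an RGB image
    return api.tensor_to_image(samples.squeeze(0).permute(1, 2, 0))

The app's "Generate Image" button simply calls this function and displays whatever PIL image it returns.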
model.py CHANGED
@@ -0,0 +1,658 @@
+ from typing import Any
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from einops import rearrange
+ from collections import defaultdict
+ import torch as th
+ import numpy as np
+ import math
+ from tqdm import tqdm
+ from PIL import Image
+
+ class GaussianDiffusion:
+     def __init__(self, model, noise_steps, beta_0, beta_T, image_size, channels=3, schedule="linear"):
+         """
+         suggested betas for:
+         * linear schedule: 1e-4, 0.02
+
+         model: the model to be trained (nn.Module)
+         noise_steps: the number of steps to apply noise (int)
+         beta_0: the initial value of beta (float)
+         beta_T: the final value of beta (float)
+         image_size: the size of the image (int, int)
+         """
+         self.device = 'cpu'
+         self.channels = channels
+
+         self.model = model
+         self.noise_steps = noise_steps
+         self.beta_0 = beta_0
+         self.beta_T = beta_T
+         self.image_size = image_size
+
+         self.betas = self.beta_schedule(schedule=schedule)
+         self.alphas = 1.0 - self.betas
+         # cumulative product of alphas, so we can optimize forward process calculation
+         self.alpha_hat = torch.cumprod(self.alphas, dim=0)
+
+     def beta_schedule(self, schedule="cosine"):
+         if schedule == "linear":
+             return torch.linspace(self.beta_0, self.beta_T, self.noise_steps).to(self.device)
+         elif schedule == "cosine":
+             return self.betas_for_cosine(self.noise_steps)
+         elif schedule == "sigmoid":
+             return self.betas_for_sigmoid(self.noise_steps)
+
+     @staticmethod
+     def sigmoid(x):
+         return 1 / (1 + np.exp(-x))
+
+     def betas_for_sigmoid(self, num_diffusion_timesteps, start=-3, end=3, tau=1.0, clip_min=1e-9):
+         betas = []
+         v_start = self.sigmoid(start/tau)
+         v_end = self.sigmoid(end/tau)
+         for t in range(num_diffusion_timesteps):
+             t_float = float(t/num_diffusion_timesteps)
+             output0 = self.sigmoid((t_float* (end-start)+start)/tau)
+             output = (v_end-output0) / (v_end-v_start)
+             betas.append(np.clip(output*.2, clip_min, .2))
+         return torch.flip(torch.tensor(betas).to(self.device), dims=[0]).float()
+
+     def betas_for_cosine(self, num_steps, start=0, end=1, tau=1, clip_min=1e-9):
+         v_start = math.cos(start*math.pi / 2) ** (2 * tau)
+         betas = []
+         v_end = math.cos(end* math.pi/2) ** 2*tau
+         for t in range(num_steps):
+             t_float = float(t)/num_steps
+             output = math.cos((t_float* (end-start)+start)*math.pi/2)**(2*tau)
+             output = (v_end - output) / (v_end-v_start)
+             betas.append(np.clip(output*.2, clip_min, .2))
+         return torch.flip(torch.tensor(betas).to(self.device), dims=[0]).float()
+
+
+     def sample_time_steps(self, batch_size=1):
+         return torch.randint(0, self.noise_steps, (batch_size,)).to(self.device)
+
+     def to(self, device):
+         self.device = device
+         self.betas = self.betas.to(device)
+         self.alphas = self.alphas.to(device)
+         self.alpha_hat = self.alpha_hat.to(device)
+
+
+     def q(self, x, t):
+         """
+         Forward process
+         """
+         pass
+
+     def p(self, x, t):
+         """
+         Backward process
+         """
+         pass
+
+
+     def apply_noise(self, x, t):
+         # force x to be (batch_size, image_width, image_height, channels)
+         if len(x.shape) == 3:
+             x = x.unsqueeze(0)
+         if type(t) == int:
+             t = torch.tensor([t])
+         #print(f'Shape -> {x.shape}, len -> {len(x.shape)}')
+         sqrt_alpha_hat = torch.sqrt(torch.tensor([self.alpha_hat[t_] for t_ in t]).to(self.device))
+         sqrt_one_minus_alpha_hat = torch.sqrt(torch.tensor([1.0 - self.alpha_hat[t_] for t_ in t]).to(self.device))
+         # standard normal distribution
+         epsilon = torch.randn_like(x).to(self.device)
+
+         # Eq 2. in DDPM paper
+         #noisy_image = sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * epsilon
+
+         """print(f'''
+         Shape of x {x.shape}
+         Shape of sqrt {sqrt_one_minus_alpha_hat.shape}''')"""
+
+         try:
+             #print(x.shape)
+             #noisy_image = torch.einsum("b,bwhc->bwhc", sqrt_alpha_hat, x.to(self.device)) + torch.einsum("b,bwhc->bwhc", sqrt_one_minus_alpha_hat, epsilon)
+             noisy_image = torch.einsum("b,bcwh->bcwh", sqrt_alpha_hat, x.to(self.device)) + torch.einsum("b,bcwh->bcwh", sqrt_one_minus_alpha_hat, epsilon)
+         except:
+             print(f'Failed image: shape {x.shape}')
+
+
+         #print(f'Noisy image -> {noisy_image.shape}')
+         # returning the noisy image and the noise which was added to the image
+         #return noisy_image, epsilon
+         #return torch.clip(noisy_image, -1.0, 1.0), epsilon
+         return noisy_image, epsilon
+
+     @staticmethod
+     def normalize_image(x):
+         # normalize image to [-1, 1]
+         return x / 255.0 * 2.0 - 1.0
+
+     @staticmethod
+     def denormalize_image(x):
+         # denormalize image to [0, 255]
+         return (x + 1.0) / 2.0 * 255.0
+
+     def sample_step(self, x, t, cond):
+         batch_size = x.shape[0]
+         device = x.device
+         z = torch.randn_like(x) if t >= 1 else torch.zeros_like(x)
+         z = z.to(device)
+         alpha = self.alphas[t]
+         one_over_sqrt_alpha = 1.0 / torch.sqrt(alpha)
+         one_minus_alpha = 1.0 - alpha
+
+         sqrt_one_minus_alpha_hat = torch.sqrt(1.0 - self.alpha_hat[t])
+         beta_hat = (1 - self.alpha_hat[t-1]) / (1 - self.alpha_hat[t]) * self.betas[t]
+         beta = self.betas[t]
+         # should we reshape the params to (batch_size, 1, 1, 1) ?
+
+
+         # we can either use beta_hat or beta_t
+         # std = torch.sqrt(beta_hat)
+         std = torch.sqrt(beta)
+         # mean + variance * z
+         if cond is not None:
+             predicted_noise = self.model(x, torch.tensor([t]).repeat(batch_size).to(device), cond)
+         else:
+             predicted_noise = self.model(x, torch.tensor([t]).repeat(batch_size).to(device))
+         mean = one_over_sqrt_alpha * (x - one_minus_alpha / sqrt_one_minus_alpha_hat * predicted_noise)
+         x_t_minus_1 = mean + std * z
+
+         return x_t_minus_1
+
+     def sample(self, num_samples, show_progress=True):
+         """
+         Sample from the model
+         """
+         cond = None
+         if self.model.is_conditional:
+             # cond is arange()
+             assert num_samples <= self.model.num_classes, "num_samples must be less than or equal to the number of classes"
+             cond = torch.arange(self.model.num_classes)[:num_samples].to(self.device)
+             cond = rearrange(cond, 'i -> i ()')
+
+         self.model.eval()
+         image_versions = []
+         with torch.no_grad():
+             x = torch.randn(num_samples, self.channels, *self.image_size).to(self.device)
+             it = reversed(range(1, self.noise_steps))
+             if show_progress:
+                 it = tqdm(it)
+             for t in it:
+                 image_versions.append(self.denormalize_image(torch.clip(x, -1, 1)).clone().squeeze(0))
+                 x = self.sample_step(x, t, cond)
+         self.model.train()
+         x = torch.clip(x, -1.0, 1.0)
+         return self.denormalize_image(x), image_versions
+
+     def validate(self, dataloader):
+         """
+         Calculate the loss on the validation set
+         """
+         self.model.eval()
+         acc_loss = 0
+         with torch.no_grad():
+             for (image, cond) in dataloader:
+                 t = self.sample_time_steps(batch_size=image.shape[0])
+                 noisy_image, added_noise = self.apply_noise(image, t)
+                 noisy_image = noisy_image.to(self.device)
+                 added_noise = added_noise.to(self.device)
+                 cond = cond.to(self.device)
+                 predicted_noise = self.model(noisy_image, t, cond)
+                 loss = nn.MSELoss()(predicted_noise, added_noise)
+                 acc_loss += loss.item()
+         self.model.train()
+         return acc_loss / len(dataloader)
+
+ class DiffusionImageAPI:
+     def __init__(self, diffusion_model):
+         self.diffusion_model = diffusion_model
+
+     def get_noisy_image(self, image, t):
+         x = torch.tensor(np.array(image))
+
+         x = self.diffusion_model.normalize_image(x)
+
+         y, _ = self.diffusion_model.apply_noise(x, t)
+
+         y = self.diffusion_model.denormalize_image(y)
+         #print(f"Shape of Image: {y.shape}")
+
+         return Image.fromarray(y.squeeze(0).numpy().astype(np.uint8))
+
+
+     def get_noisy_images(self, image, time_steps):
+         """
+         image: the image to be processed PIL.Image
+         time_steps: the number of time steps to apply noise (int)
+         """
+
+         return [self.get_noisy_image(image, int(t)) for t in time_steps]
+
+     def tensor_to_image(self, tensor):
+         return Image.fromarray(tensor.cpu().numpy().astype(np.uint8))
+
+
+
+
+
+
+
+
+
+
+ str_to_act = defaultdict(lambda: nn.SiLU())
+ str_to_act.update({
+     "relu": nn.ReLU(),
+     "silu": nn.SiLU(),
+     "gelu": nn.GELU(),
+ })
+
+ class SinusoidalPositionalEncoding(nn.Module):
+     def __init__(self, dim):
+         super().__init__()
+         self.dim = dim
+
+     def forward(self, t):
+         device = t.device
+         t = t.unsqueeze(-1)
+         inv_freq = 1.0 / (10000 ** (torch.arange(0, self.dim, 2, device=device).float() / self.dim))
+         sin_enc = torch.sin(t.repeat(1, self.dim // 2) * inv_freq)
+         cos_enc = torch.cos(t.repeat(1, self.dim // 2) * inv_freq)
+         pos_enc = torch.cat([sin_enc, cos_enc], dim=-1)
+         return pos_enc
+
+ class TimeEmbedding(nn.Module):
+     def __init__(self, model_dim: int, emb_dim: int, act="silu"):
+         super().__init__()
+
+         self.lin = nn.Linear(model_dim, emb_dim)
+         self.act = str_to_act[act]
+         self.lin2 = nn.Linear(emb_dim, emb_dim)
+
+     def forward(self, x):
+         x = self.lin(x)
+         x = self.act(x)
+         x = self.lin2(x)
+         return x
+
+ class ConvBlock(nn.Module):
+     def __init__(self, in_channels, out_channels, act="silu", dropout=None, zero=False):
+         super().__init__()
+
+         self.norm = nn.GroupNorm(
+             num_groups=32,
+             num_channels=in_channels,
+         )
+
+         self.act = str_to_act[act]
+
+         if dropout is not None:
+             self.dropout = nn.Dropout(dropout)
+
+         self.conv = nn.Conv2d(
+             in_channels=in_channels,
+             out_channels=out_channels,
+             kernel_size=3,
+             padding=1,
+         )
+         if zero:
+             self.conv.weight.data.zero_()
+
+
+     def forward(self, x):
+         x = self.norm(x)
+         x = self.act(x)
+         if hasattr(self, "dropout"):
+             x = self.dropout(x)
+         x = self.conv(x)
+         return x
+
+ class EmbeddingBlock(nn.Module):
+     def __init__(self, channels: int, emb_dim: int, act="silu"):
+         super().__init__()
+
+         self.act = str_to_act[act]
+         self.lin = nn.Linear(emb_dim, channels)
+
+     def forward(self, x):
+         x = self.act(x)
+         x = self.lin(x)
+         return x
+
+ class ResBlock(nn.Module):
+     def __init__(self, channels: int, emb_dim: int, dropout: float = 0, out_channels=None):
+         """A resblock with a time embedding and an optional change in channel count
+         """
+         if out_channels is None:
+             out_channels = channels
+         super().__init__()
+
+         self.conv1 = ConvBlock(channels, out_channels)
+
+         self.emb = EmbeddingBlock(out_channels, emb_dim)
+
+         self.conv2 = ConvBlock(out_channels, out_channels, dropout=dropout, zero=True)
+
+         if channels != out_channels:
+             self.skip_connection = nn.Conv2d(channels, out_channels, kernel_size=1)
+         else:
+             self.skip_connection = nn.Identity()
+
+
+     def forward(self, x, t):
+         original = x
+         x = self.conv1(x)
+
+         t = self.emb(t)
+         # t: (batch_size, time_embedding_dim) = (batch_size, out_channels)
+         # x: (batch_size, out_channels, height, width)
+         # we repeat the time embedding to match the shape of x
+         t = t.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, x.shape[2], x.shape[3])
+
+         x = x + t
+
+         x = self.conv2(x)
+         x = x + self.skip_connection(original)
+         return x
+
+ class SelfAttentionBlock(nn.Module):
+     def __init__(self, channels, num_heads=1):
+         super().__init__()
+         self.channels = channels
+         self.num_heads = num_heads
+
+         self.norm = nn.GroupNorm(32, channels)
+
+         self.attention = nn.MultiheadAttention(
+             embed_dim=channels,
+             num_heads=num_heads,
+             dropout=0,
+             batch_first=True,
+             bias=True,
+         )
+
+     def forward(self, x):
+         h, w = x.shape[-2:]
+         original = x
+         x = self.norm(x)
+         x = rearrange(x, "b c h w -> b (h w) c")
+         x = self.attention(x, x, x)[0]
+         x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
+         return x + original
+
+ class Downsample(nn.Module):
+     def __init__(self, channels):
+         super().__init__()
+         # ddpm uses maxpool
+         # self.down = nn.MaxPool2d
+
+         # iddpm uses strided conv
+         self.down = nn.Conv2d(
+             in_channels=channels,
+             out_channels=channels,
+             kernel_size=3,
+             stride=2,
+             padding=1,
+         )
+
+     def forward(self, x):
+         return self.down(x)
+
+ class DownBlock(nn.Module):
+     """According to U-Net paper
+
+     'The contracting path follows the typical architecture of a convolutional network.
+     It consists of the repeated application of two 3x3 convolutions (unpadded convolutions),
+     each followed by a rectified linear unit (ReLU) and a 2x2 max pooling operation with stride 2
+     for downsampling. At each downsampling step we double the number of feature channels.'
+     """
+
+     def __init__(self, in_channels, out_channels, time_embedding_dim, use_attn=False, dropout=0, downsample=True, width=1):
+         """in_channels will typically be half of out_channels"""
+         super().__init__()
+         self.width = width
+         self.use_attn = use_attn
+         self.do_downsample = downsample
+
+         self.blocks = nn.ModuleList()
+         for _ in range(width):
+             self.blocks.append(ResBlock(
+                 channels=in_channels,
+                 out_channels=out_channels,
+                 emb_dim=time_embedding_dim,
+                 dropout=dropout,
+             ))
+             if self.use_attn:
+                 self.blocks.append(SelfAttentionBlock(
+                     channels=out_channels,
+                 ))
+             in_channels = out_channels
+
+         if self.do_downsample:
+             self.downsample = Downsample(out_channels)
+
+     def forward(self, x, t):
+         for block in self.blocks:
+             if isinstance(block, ResBlock):
+                 x = block(x, t)
+             elif isinstance(block, SelfAttentionBlock):
+                 x = block(x)
+
+         residual = x
+         if self.do_downsample:
+             x = self.downsample(x)
+         return x, residual
+
+ class Upsample(nn.Module):
+     def __init__(self, channels):
+         super().__init__()
+         self.upsample = nn.Upsample(scale_factor=2)
+         self.conv = nn.Conv2d(
+             in_channels=channels,
+             out_channels=channels,
+             kernel_size=3,
+             padding=1,
+         )
+
+     def forward(self, x):
+         x = self.upsample(x)
+         x = self.conv(x)
+         return x
+
+ class UpBlock(nn.Module):
+     """According to U-Net paper
+
+     Every step in the expansive path consists of an upsampling of the feature map followed by a 2x2
+     convolution (“up-convolution”) that halves the number of feature channels, a concatenation with
+     the correspondingly cropped feature map from the contracting path, and two 3x3 convolutions,
+     each followed by a ReLU.
+     """
+
+     def __init__(self, in_channels, out_channels, time_embedding_dim, use_attn=False, dropout=0, upsample=True, width=1):
+         """in_channels will typically be double of out_channels
+         """
+         super().__init__()
+         self.use_attn = use_attn
+         self.do_upsample = upsample
+
+         self.blocks = nn.ModuleList()
+         for _ in range(width):
+             self.blocks.append(ResBlock(
+                 channels=in_channels,
+                 out_channels=out_channels,
+                 emb_dim=time_embedding_dim,
+                 dropout=dropout,
+             ))
+             if self.use_attn:
+                 self.blocks.append(SelfAttentionBlock(
+                     channels=out_channels,
+                 ))
+             in_channels = out_channels
+
+         if self.do_upsample:
+             self.upsample = Upsample(out_channels)
+
+     def forward(self, x, t):
+         for block in self.blocks:
+             if isinstance(block, ResBlock):
+                 x = block(x, t)
+             elif isinstance(block, SelfAttentionBlock):
+                 x = block(x)
+
+         if self.do_upsample:
+             x = self.upsample(x)
+         return x
+
+ class Bottleneck(nn.Module):
+     def __init__(self, channels, dropout, time_embedding_dim):
+         super().__init__()
+         in_channels = channels
+         out_channels = channels
+         self.resblock_1 = ResBlock(
+             channels=in_channels,
+             out_channels=out_channels,
+             dropout=dropout,
+             emb_dim=time_embedding_dim
+         )
+         self.attention_block = SelfAttentionBlock(
+             channels=out_channels,
+         )
+         self.resblock_2 = ResBlock(
+             channels=out_channels,
+             out_channels=out_channels,
+             dropout=dropout,
+             emb_dim=time_embedding_dim
+         )
+
+     def forward(self, x, t):
+         x = self.resblock_1(x, t)
+         x = self.attention_block(x)
+         x = self.resblock_2(x, t)
+         return x
+
+ class Unet(nn.Module):
+     def __init__(
+         self,
+         image_channels=3,
+         res_block_width=2,
+         starting_channels=128,
+         dropout=0,
+         channel_mults=(1, 2, 2, 4, 4),
+         attention_layers=(False, False, False, True, False)
+     ):
+         super().__init__()
+         self.is_conditional = False
+
+         self.image_channels = image_channels
+         self.starting_channels = starting_channels
+         time_embedding_dim = 4 * starting_channels
+
+         self.time_encoding = SinusoidalPositionalEncoding(dim=starting_channels)
+         self.time_embedding = TimeEmbedding(model_dim=starting_channels, emb_dim=time_embedding_dim)
+
+         self.input = nn.Conv2d(3, starting_channels, kernel_size=3, padding=1)
+
+         current_channel_count = starting_channels
+
+         input_channel_counts = []
+         self.contracting_path = nn.ModuleList([])
+         for i, channel_multiplier in enumerate(channel_mults):
+             is_last_layer = i == len(channel_mults) - 1
+             next_channel_count = channel_multiplier * starting_channels
+
+             self.contracting_path.append(DownBlock(
+                 in_channels=current_channel_count,
+                 out_channels=next_channel_count,
+                 time_embedding_dim=time_embedding_dim,
+                 use_attn=attention_layers[i],
+                 dropout=dropout,
+                 downsample=not is_last_layer,
+                 width=res_block_width,
+             ))
+             current_channel_count = next_channel_count
+
+             input_channel_counts.append(current_channel_count)
+
+         self.bottleneck = Bottleneck(channels=current_channel_count, time_embedding_dim=time_embedding_dim, dropout=dropout)
+
+         self.expansive_path = nn.ModuleList([])
+         for i, channel_multiplier in enumerate(reversed(channel_mults)):
+             next_channel_count = channel_multiplier * starting_channels
+
+             self.expansive_path.append(UpBlock(
+                 in_channels=current_channel_count + input_channel_counts.pop(),
+                 out_channels=next_channel_count,
+                 time_embedding_dim=time_embedding_dim,
+                 use_attn=list(reversed(attention_layers))[i],
+                 dropout=dropout,
+                 upsample=i != len(channel_mults) - 1,
+                 width=res_block_width,
+             ))
+             current_channel_count = next_channel_count
+
+         last_conv = nn.Conv2d(
+             in_channels=starting_channels,
+             out_channels=image_channels,
+             kernel_size=3,
+             padding=1,
+         )
+         last_conv.weight.data.zero_()
+
+         self.head = nn.Sequential(
+             nn.GroupNorm(32, starting_channels),
+             nn.SiLU(),
+             last_conv,
+         )
+
+     def forward(self, x, t):
+         t = self.time_encoding(t)
+         return self._forward(x, t)
+
+     def _forward(self, x, t):
+         t = self.time_embedding(t)
+
+         x = self.input(x)
+
+         residuals = []
+         for contracting_block in self.contracting_path:
+             x, residual = contracting_block(x, t)
+             residuals.append(residual)
+
+         x = self.bottleneck(x, t)
+
+         for expansive_block in self.expansive_path:
+             # Add the residual
+             residual = residuals.pop()
+             x = torch.cat([x, residual], dim=1)
+
+             x = expansive_block(x, t)
+
+         x = self.head(x)
+         return x
+
+ class ConditionalUnet(nn.Module):
+     def __init__(self, unet, num_classes):
+         super().__init__()
+         self.is_conditional = True
+
+         self.unet = unet
+         self.num_classes = num_classes
+
+         self.class_embedding = nn.Embedding(num_classes, unet.starting_channels)
+
+     def forward(self, x, t, cond=None):
+         # cond: (batch_size, n), where n is the number of classes that we are conditioning on
+         t = self.unet.time_encoding(t)
+
+         if cond is not None:
+             cond = self.class_embedding(cond)
+             # sum across the classes so we get a single vector representing the set of classes
+             cond = cond.sum(dim=1)
+             t += cond
+
+         return self.unet._forward(x, t)
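
For context, apply_noise above is the closed-form forward process from the DDPM paper (its "Eq 2."), x_t = sqrt(alpha_hat_t) * x_0 + sqrt(1 - alpha_hat_t) * epsilon, and the U-Net is trained to predict epsilon from (x_t, t). A minimal training-step sketch built on these classes might look as follows; the optimizer, learning rate, resolution, and schedule values are assumptions, not part of this commit:

# Hypothetical training step, assuming batches of images already normalized to [-1, 1]
import torch
import torch.nn as nn
from model import Unet, GaussianDiffusion

model = Unet(image_channels=3, starting_channels=128)
diffusion = GaussianDiffusion(model, noise_steps=1000, beta_0=1e-4, beta_T=0.02, image_size=(64, 64))
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)

def training_step(batch):
    # batch: (batch_size, channels, height, width)
    t = diffusion.sample_time_steps(batch_size=batch.shape[0])
    noisy, noise = diffusion.apply_noise(batch, t)   # x_t and the epsilon that produced it
    predicted_noise = model(noisy, t)                # the U-Net predicts epsilon from x_t and t
    loss = nn.MSELoss()(predicted_noise, noise)      # same objective as in validate()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()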
requirements.txt CHANGED
@@ -1,4 +1,7 @@
  streamlit
  torch
  torchvision
- numpy
+ numpy
+ einops
+ pillow
+ tqdm
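
With these dependencies installed (for example via pip install -r requirements.txt), the app would typically be started with streamlit run app.py, which renders the "Generate Image" button defined in app.py.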