pawlo2013 committed
Commit 0b2b0ab · Parent: 27a598c

init commit

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.st filter=lfs diff=lfs merge=lfs -text
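
The added rule routes safetensors checkpoints (*.st) through Git LFS. Assuming Git LFS is installed, the same rule could have been appended with the standard CLI:

    git lfs track "*.st"
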
.gitignore ADDED
@@ -0,0 +1,40 @@
+#.model.pth
+
+
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+
+# vim swp files
+*.swp
+# caffe/pytorch model files
+*.pth
+
+*.pt
+# json
+*.json
+
+*.bin
+
+*.st
+
+.models/model-epoch_80.st
+.history/
+
+dataset/
+
+wandb/
+
+
+.vscode/
+https://github.com/higumax/sketchKeras-pytorch.git
+
+.startup.sh
+
+startup.sh
+
+
+
+
+
app.py ADDED
@@ -0,0 +1,136 @@
+import gradio as gr
+from PIL import Image
+import numpy as np
+from torchvision import transforms
+from load_model import sample
+import torch
+import random
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+device = "mps" if torch.backends.mps.is_available() else device
+
+image_size = 128
+upscale = False
+clicked = False
+
+
+transform = transforms.Compose(
+    [
+        transforms.Resize((image_size, image_size)),
+        transforms.ToTensor(),
+        transforms.Lambda(lambda t: (t * 2) - 1),
+    ]
+)
+
+
+def make_scribbles(sketch, scribbles):
+    # pixels equal to the grey background value (127/255) carry no scribble and are filled in from the sketch
+    sketch = transforms.Resize((image_size, image_size))(sketch)
+    scribbles = transforms.Resize((image_size, image_size))(scribbles)
+
+    grey_tensor = torch.tensor(0.49803922, device=device)
+
+    grey_tensor = grey_tensor.expand(3, image_size, image_size)
+
+    sketch = transforms.ToTensor()(sketch).to(device)
+    scribbles = transforms.ToTensor()(scribbles).to(device)
+
+    scribble_where_grey_mask = torch.eq(scribbles, grey_tensor)
+
+    merged = torch.where(scribble_where_grey_mask, sketch, scribbles)
+
+    return transforms.Lambda(lambda t: (t * 2) - 1)(sketch), transforms.Lambda(
+        lambda t: (t * 2) - 1
+    )(merged)
+
+
+def process_images(sketch, scribbles, sampling_steps, is_scribbles, seed_nr, upscale):
+    global clicked
+    clicked = True
+    w, h = sketch.size
+
+    if is_scribbles:
+        sketch, scribbles = make_scribbles(sketch, scribbles)
+
+    else:
+        sketch = transform(sketch.convert("RGB"))
+        scribbles = transform(scribbles.convert("RGB"))
+
+    if upscale:
+        output = transforms.Resize((h, w))(
+            sample(sketch, scribbles, sampling_steps, seed_nr)
+        )
+        clicked = False
+        return output
+    else:
+        output = sample(sketch, scribbles, sampling_steps, seed_nr)
+        clicked = False
+        return output
+
+
+theme = gr.themes.Monochrome()
+
+
+with gr.Blocks(theme=theme) as demo:
+    with gr.Row():
+        gr.Markdown(
+            "<h1 style='text-align: center; font-size: 30px;'>Image Inpainting with Conditional Diffusion by MedicAI</h1>"
+        )
+
+    with gr.Row():
+        with gr.Column():
+            sketch_input = gr.Image(type="pil", label="Sketch", height=500)
+        with gr.Column():
+            scribbles_input = gr.Image(type="pil", label="Scribbles", height=500)
+            info = gr.Markdown(
+                "<p style='text-align: center; font-size: 12px;'>"
+                "By default the scribbles are assumed to be merged with the sketch; if they appear on a grey background, check the box below. "
+                "</p>"
+            )
+            is_scribbles = gr.Checkbox(label="Is Scribbles", value=False)
+        with gr.Column():
+            output = gr.Image(type="pil", label="Output")
+            upscale_info = gr.Markdown(
+                "<p style='text-align: center; font-size: 12px;'>"
+                f"To stretch the downloadable output back to the input size, check the box below; the network's native output is {image_size}x{image_size}. "
+                "</p>"
+            )
+            upscale_button = gr.Checkbox(label="Stretch", value=False)
+    with gr.Row():
+        with gr.Column():
+            seed_slider = gr.Number(
+                label="Random Seed 🎲",
+                value=random.randint(
+                    1,
+                    1000,
+                ),
+            )
+
+        with gr.Column():
+            sampling_slider = gr.Slider(
+                minimum=1,
+                maximum=250,
+                step=1,
+                label="DDPM Sampling Steps 🔄",
+                value=50,
+            )
+
+    with gr.Row():
+        generate_button = gr.Button(value="Generate", interactive=not clicked)
+
+    generate_button.click(
+        process_images,
+        inputs=[
+            sketch_input,
+            scribbles_input,
+            sampling_slider,
+            is_scribbles,
+            seed_slider,
+            upscale_button,
+        ],
+        outputs=output,
+        show_progress=True,
+    )
+
+
+demo.launch(server_port=3000, max_threads=1)
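
The "Is Scribbles" path relies on make_scribbles treating the exact grey value 127/255 (≈ 0.49803922 after ToTensor) as empty background. A minimal sketch of preparing a compatible scribble canvas with PIL (file name hypothetical):

    from PIL import Image, ImageDraw

    # grey background that make_scribbles interprets as "no scribble"
    canvas = Image.new("RGB", (512, 512), (127, 127, 127))
    draw = ImageDraw.Draw(canvas)
    # a red colour hint; any non-grey pixels survive the merge
    draw.line((60, 60, 200, 200), fill=(255, 0, 0), width=8)
    canvas.save("scribbles.png")
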
load_model.py ADDED
@@ -0,0 +1,110 @@
+from models.structure.Unet_3 import Unet
+from diffusers import DDPMScheduler
+import torch
+import os
+import glob
+from tqdm import tqdm
+from torchvision import transforms
+import pathlib
+from torchvision.utils import save_image
+from safetensors.torch import load_model, save_model
+
+
+denoising_timesteps = 4000
+image_size = 128
+channels = 3
+
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+device = "mps" if torch.backends.mps.is_available() else device
+
+model = Unet(
+    dim=image_size,
+    channels=channels,
+    dim_mults=(1, 2, 4, 8),
+    use_convnext=False,
+).to(device)
+
+results_folder = pathlib.Path("models")
+
+
+checkpoint_files_st = glob.glob(str(results_folder / "model-epoch_*.st"))
+checkpoint_files_pt = glob.glob(str(results_folder / "model-epoch_*.pt"))
+
+if checkpoint_files_st:
+    # sort the matching safetensor checkpoints by modification time (newest first)
+    checkpoint_files_st.sort(key=lambda x: os.path.getmtime(x), reverse=True)
+    # select the newest file
+    checkpoint_files = checkpoint_files_st[0]
+    # checkpoint_files now points at the newest checkpoint
+    load_model(model, checkpoint_files)
+    model.eval()
+    print("Loaded model from checkpoint", checkpoint_files)
+
+elif checkpoint_files_pt:
+    # sort the matching PyTorch checkpoints by modification time (newest first)
+    checkpoint_files_pt.sort(key=lambda x: os.path.getmtime(x), reverse=True)
+    # select the newest file
+    checkpoint_files = checkpoint_files_pt[0]
+    # checkpoint_files now points at the newest checkpoint
+    checkpoint = torch.load(checkpoint_files, map_location=device)
+    model.load_state_dict(checkpoint["model_state_dict"])
+    epoch = checkpoint["epoch"]
+    model.eval()
+    print("Loaded model from checkpoint", checkpoint_files)
+
+    if not glob.glob(str(results_folder / "model-epoch_*.st")):  # Path.exists() would not expand the glob
+        save_model(model, results_folder / "model-epoch_{}.st".format(epoch))
+        print("Saved model as a safetensor in", results_folder)
+
+else:
+    raise Exception("No model files found in the folder.")
+
+
+def sample(sketch, scribbles, sampling_steps, seed_nr):
+    torch.manual_seed(seed_nr)
+
+    noise_scheduler = DDPMScheduler(
+        num_train_timesteps=denoising_timesteps, beta_schedule="squaredcos_cap_v2"
+    )
+    noise_scheduler.set_timesteps(sampling_steps, device=device)
+
+    sketch = sketch.to(device)
+    scribbles = scribbles.to(device)
+
+    sketch = sketch.unsqueeze(0)
+    scribbles = scribbles.unsqueeze(0)
+
+    with torch.no_grad():
+        b = sketch.shape[0]
+
+        noise_for_plain = torch.randn_like(sketch, device=device)
+
+        for i, t in tqdm(
+            enumerate(noise_scheduler.timesteps),
+            total=len(noise_scheduler.timesteps),
+        ):
+            noise_for_plain = noise_scheduler.scale_model_input(noise_for_plain, t).to(
+                device
+            )
+
+            time = t.expand(
+                b,
+            ).to(device)
+
+            plain_noise_pred = model(
+                x=noise_for_plain,
+                time=time,
+                implicit_conditioning=scribbles,
+                explicit_conditioning=sketch,
+            )
+
+            noise_for_plain = noise_scheduler.step(
+                plain_noise_pred,
+                t.long(),
+                noise_for_plain,
+            ).prev_sample
+
+    sample = torch.clamp((noise_for_plain / 2) + 0.5, 0, 1)
+
+    return transforms.ToPILImage()(sample[0].cpu())
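
sample expects [-1, 1]-normalized CHW tensors and returns a PIL image; note that importing the module loads the newest checkpoint from models/ as a side effect. A minimal usage sketch (file names hypothetical, assuming a checkpoint is present):

    from PIL import Image
    from torchvision import transforms
    from load_model import sample, image_size

    to_model = transforms.Compose(
        [
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Lambda(lambda t: (t * 2) - 1),  # map [0, 1] -> [-1, 1]
        ]
    )

    sketch = to_model(Image.open("sketch.png").convert("RGB"))
    scribbles = to_model(Image.open("scribbles.png").convert("RGB"))
    result = sample(sketch, scribbles, sampling_steps=50, seed_nr=0)
    result.save("inpainted.png")
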
models/structure/Advanced_Network_Helpers.py ADDED
@@ -0,0 +1,255 @@
+import math
+from inspect import isfunction
+from functools import partial
+import matplotlib.pyplot as plt
+from tqdm.auto import tqdm
+from einops import rearrange
+import torch
+from torch import nn, einsum
+import torch.nn.functional as F
+
+
+def exists(x):
+    return x is not None
+
+
+def default(val, d):
+    if exists(val):
+        return val
+    return d() if isfunction(d) else d
+
+
+class Residual(nn.Module):
+    def __init__(self, fn):
+        super().__init__()
+        self.fn = fn
+
+    def forward(self, x, *args, **kwargs):
+        return self.fn(x, *args, **kwargs) + x
+
+
+def Upsample(dim):
+    return nn.ConvTranspose2d(dim, dim, 4, 2, 1)
+
+
+def Downsample(dim):
+    return nn.Conv2d(dim, dim, 4, 2, 1)
+
+
+class SinusoidalPositionEmbeddings(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, time):
+        device = time.device
+        half_dim = self.dim // 2
+        embeddings = math.log(10000) / (half_dim - 1)
+        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
+        embeddings = time[:, None] * embeddings[None, :]
+        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
+        return embeddings
+
+
+class Block(nn.Module):
+    def __init__(self, dim, dim_out, groups=8):
+        super().__init__()
+        self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)
+        self.norm = nn.GroupNorm(groups, dim_out)
+        self.act = nn.SiLU()
+
+    def forward(self, x, scale_shift=None):
+        x = self.proj(x)
+        x = self.norm(x)
+
+        if exists(scale_shift):
+            scale, shift = scale_shift
+            x = x * (scale + 1) + shift
+
+        x = self.act(x)
+        return x
+
+
+class ResnetBlock(nn.Module):
+    """https://arxiv.org/abs/1512.03385"""
+
+    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
+        super().__init__()
+        self.mlp = (
+            nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out))
+            if exists(time_emb_dim)
+            else None
+        )
+
+        self.block1 = Block(dim, dim_out, groups=groups)
+        self.block2 = Block(dim_out, dim_out, groups=groups)
+        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
+
+    def forward(self, x, time_emb=None):
+        h = self.block1(x)
+
+        if exists(self.mlp) and exists(time_emb):
+            time_emb = self.mlp(time_emb)
+            h = rearrange(time_emb, "b c -> b c 1 1") + h
+
+        h = self.block2(h)
+        return h + self.res_conv(x)
+
+
+class ConvNextBlock(nn.Module):
+    """https://arxiv.org/abs/2201.03545"""
+
+    def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
+        super().__init__()
+        self.mlp = (
+            nn.Sequential(nn.GELU(), nn.Linear(time_emb_dim, dim))
+            if exists(time_emb_dim)
+            else None
+        )
+
+        self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)
+
+        self.net = nn.Sequential(
+            nn.GroupNorm(1, dim) if norm else nn.Identity(),
+            nn.Conv2d(dim, dim_out * mult, 3, padding=1),
+            nn.GELU(),
+            nn.GroupNorm(1, dim_out * mult),
+            nn.Conv2d(dim_out * mult, dim_out, 3, padding=1),
+        )
+
+        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
+
+    def forward(self, x, time_emb=None):
+        h = self.ds_conv(x)
+
+        if exists(self.mlp) and exists(time_emb):
+            assert exists(time_emb), "time embedding must be passed in"
+            condition = self.mlp(time_emb)
+            h = h + rearrange(condition, "b c -> b c 1 1")
+
+        h = self.net(h)
+        return h + self.res_conv(x)
+
+
+class Attention(nn.Module):
+    def __init__(self, dim, heads=4, dim_head=32):
+        super().__init__()
+        self.scale = dim_head**-0.5
+        self.heads = heads
+        hidden_dim = dim_head * heads
+        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
+        self.to_q = nn.Conv2d(dim, hidden_dim, 1, bias=False)
+        self.to_k = nn.Conv2d(dim, hidden_dim, 1, bias=False)
+        self.to_v = nn.Conv2d(dim, hidden_dim, 1, bias=False)
+        self.to_out = nn.Conv2d(hidden_dim, dim, 1)
+
+    def forward(self, x, cross_attend=None):
+        b, c, h, w = x.shape
+
+        if cross_attend is not None:
+            assert cross_attend.shape == x.shape
+
+            q_att = self.to_q(x)
+            k_att = self.to_k(cross_attend)
+            v_att = self.to_v(cross_attend)
+            q = rearrange(q_att, "b (h c) x y -> b h c (x y)", h=self.heads)
+            k = rearrange(k_att, "b (h c) x y -> b h c (x y)", h=self.heads)
+            v = rearrange(v_att, "b (h c) x y -> b h c (x y)", h=self.heads)
+        else:
+            qkv = self.to_qkv(x).chunk(3, dim=1)
+            q, k, v = map(
+                lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
+            )
+        q = q * self.scale
+
+        sim = einsum("b h d i, b h d j -> b h i j", q, k)
+        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
+        attn = sim.softmax(dim=-1)
+
+        out = einsum("b h i j, b h d j -> b h i d", attn, v)
+        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
+
+        return self.to_out(out)
+
+
+class LinearCrossAttention(nn.Module):
+    def __init__(self, dim, heads=12, dim_head=128) -> None:
+        super().__init__()
+        self.scale = dim_head**-0.5
+        self.heads = heads
+        hidden_dim = dim_head * heads
+        self.to_kv = nn.Conv2d(dim, hidden_dim * 2, 1, bias=False)
+        self.to_q = nn.Conv2d(dim, hidden_dim, 1, bias=False)
+        self.out = nn.Conv2d(hidden_dim, dim, 1)
+
+    def forward(self, x, cross_attend):
+        b, c, h, w = x.shape
+        q = self.to_q(x)
+        k, v = self.to_kv(cross_attend).chunk(2, dim=1)
+        q = rearrange(q, "b (h c) x y -> b h c (x y)", h=self.heads)
+        k = rearrange(k, "b (h c) x y -> b h c (x y)", h=self.heads)
+        v = rearrange(v, "b (h c) x y -> b h c (x y)", h=self.heads)
+        q = q * self.scale
+        sim = einsum("b h d i, b h d j -> b h i j", q, k)
+        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
+        attn = sim.softmax(dim=-1)
+        out = einsum("b h i j, b h d j -> b h i d", attn, v)
+        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
+        return self.out(out)
+
+
+class LinearAttention(nn.Module):
+    def __init__(self, dim, heads=4, dim_head=32):
+        super().__init__()
+        self.scale = dim_head**-0.5
+        self.heads = heads
+        hidden_dim = dim_head * heads
+        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
+        self.to_q = nn.Conv2d(dim, hidden_dim, 1, bias=False)
+        self.to_k = nn.Conv2d(dim, hidden_dim, 1, bias=False)
+        self.to_v = nn.Conv2d(dim, hidden_dim, 1, bias=False)
+        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), nn.GroupNorm(1, dim))
+
+    def forward(self, x, cross_attend=None):
+        b, c, h, w = x.shape
+        if cross_attend is not None:
+            assert (
+                cross_attend.shape == x.shape
+            ), f"cross_attend must have the same shape as x; got {cross_attend.shape} and {x.shape}"
+
+            q_att = self.to_q(x)
+            k_att = self.to_k(cross_attend)
+            v_att = self.to_v(cross_attend)
+            q = rearrange(q_att, "b (h c) x y -> b h c (x y)", h=self.heads)
+            k = rearrange(k_att, "b (h c) x y -> b h c (x y)", h=self.heads)
+            v = rearrange(v_att, "b (h c) x y -> b h c (x y)", h=self.heads)
+
+        else:
+            qkv = self.to_qkv(x).chunk(3, dim=1)
+            q, k, v = map(
+                lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
+            )
+        # linear attention: softmax q over the feature dimension (softmax of q^T along its last dim)
+        q = q.softmax(dim=-2)
+        # softmax k over the spatial dimension
+        k = k.softmax(dim=-1)
+        # scale the queries
+        q = q * self.scale
+        # aggregate k and v first to form a (feature x feature) context
+        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)
+        # project the context back through q
+        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
+        # restore the (b, c, h, w) layout expected by the conv output head
+        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
+        return self.to_out(out)
+
+
+class PreNorm(nn.Module):
+    def __init__(self, dim, fn):
+        super().__init__()
+        self.fn = fn
+        self.norm = nn.GroupNorm(1, dim)
+
+    def forward(self, x, *args, **kwargs):
+        x = self.norm(x)
+        return self.fn(x, *args, **kwargs)
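
All attention modules in this file are shape-preserving over (batch, dim, height, width) feature maps; with this file's defaults, LinearCrossAttention uses 12 heads of width 128. A quick shape sanity check (dimensions chosen arbitrarily):

    import torch
    from models.structure.Advanced_Network_Helpers import LinearAttention, LinearCrossAttention

    x = torch.randn(1, 64, 16, 16)     # (batch, dim, height, width)
    cond = torch.randn(1, 64, 16, 16)  # conditioning features, same shape

    print(LinearAttention(dim=64)(x).shape)             # torch.Size([1, 64, 16, 16])
    print(LinearCrossAttention(dim=64)(x, cond).shape)  # torch.Size([1, 64, 16, 16])
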
models/structure/Advanced_Network_Helpers_2.py ADDED
@@ -0,0 +1,232 @@
+import math
+from inspect import isfunction
+from functools import partial
+import matplotlib.pyplot as plt
+from tqdm.auto import tqdm
+from einops import rearrange
+import torch
+from torch import nn, einsum
+import torch.nn.functional as F
+
+
+def exists(x):
+    return x is not None
+
+
+def default(val, d):
+    if exists(val):
+        return val
+    return d() if isfunction(d) else d
+
+
+class Residual(nn.Module):
+    def __init__(self, fn):
+        super().__init__()
+        self.fn = fn
+
+    def forward(self, x, *args, **kwargs):
+        return self.fn(x, *args, **kwargs) + x
+
+
+def Upsample(dim):
+    return nn.ConvTranspose2d(dim, dim, 4, 2, 1)
+
+
+def Downsample(dim):
+    return nn.Conv2d(dim, dim, 4, 2, 1)
+
+
+class SinusoidalPositionEmbeddings(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, time):
+        device = time.device
+        half_dim = self.dim // 2
+        embeddings = math.log(10000) / (half_dim - 1)
+        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
+        embeddings = time[:, None] * embeddings[None, :]
+        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
+        return embeddings
+
+
+class Block(nn.Module):
+    def __init__(self, dim, dim_out, groups=8):
+        super().__init__()
+        self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)
+        self.norm = nn.GroupNorm(groups, dim_out)
+        self.act = nn.SiLU()
+
+    def forward(self, x, scale_shift=None):
+        x = self.proj(x)
+        x = self.norm(x)
+
+        if exists(scale_shift):
+            scale, shift = scale_shift
+            x = x * (scale + 1) + shift
+
+        x = self.act(x)
+        return x
+
+
+class ResnetBlock(nn.Module):
+    """https://arxiv.org/abs/1512.03385"""
+
+    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
+        super().__init__()
+        self.mlp = (
+            nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out))
+            if exists(time_emb_dim)
+            else None
+        )
+
+        self.block1 = Block(dim, dim_out, groups=groups)
+        self.block2 = Block(dim_out, dim_out, groups=groups)
+        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
+
+    def forward(self, x, time_emb=None):
+        h = self.block1(x)
+
+        if exists(self.mlp) and exists(time_emb):
+            time_emb = self.mlp(time_emb)
+            h = rearrange(time_emb, "b c -> b c 1 1") + h
+
+        h = self.block2(h)
+        return h + self.res_conv(x)
+
+
+class ConvNextBlock(nn.Module):
+    """https://arxiv.org/abs/2201.03545"""
+
+    def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
+        super().__init__()
+        self.mlp = (
+            nn.Sequential(nn.GELU(), nn.Linear(time_emb_dim, dim))
+            if exists(time_emb_dim)
+            else None
+        )
+
+        self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)
+
+        self.net = nn.Sequential(
+            nn.GroupNorm(1, dim) if norm else nn.Identity(),
+            nn.Conv2d(dim, dim_out * mult, 3, padding=1),
+            nn.GELU(),
+            nn.GroupNorm(1, dim_out * mult),
+            nn.Conv2d(dim_out * mult, dim_out, 3, padding=1),
+        )
+
+        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
+
+    def forward(self, x, time_emb=None):
+        h = self.ds_conv(x)
+
+        if exists(self.mlp) and exists(time_emb):
+            assert exists(time_emb), "time embedding must be passed in"
+            condition = self.mlp(time_emb)
+            h = h + rearrange(condition, "b c -> b c 1 1")
+
+        h = self.net(h)
+        return h + self.res_conv(x)
+
+
+class Attention(nn.Module):
+    def __init__(self, dim, heads=4, dim_head=32):
+        super().__init__()
+        self.scale = dim_head**-0.5
+        self.heads = heads
+        hidden_dim = dim_head * heads
+        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
+        self.to_q = nn.Conv2d(dim, hidden_dim, 1, bias=False)
+        self.to_k = nn.Conv2d(dim, hidden_dim, 1, bias=False)
+        self.to_v = nn.Conv2d(dim, hidden_dim, 1, bias=False)
+        self.to_out = nn.Conv2d(hidden_dim, dim, 1)
+
+    def forward(self, x):
+        b, c, h, w = x.shape
+
+        qkv = self.to_qkv(x).chunk(3, dim=1)
+        q, k, v = map(
+            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
+        )
+        q = q * self.scale
+
+        sim = einsum("b h d i, b h d j -> b h i j", q, k)
+        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
+        attn = sim.softmax(dim=-1)
+
+        out = einsum("b h i j, b h d j -> b h i d", attn, v)
+        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
+
+        return self.to_out(out)
+
+
+class LinearCrossAttention(nn.Module):
+    def __init__(self, dim, heads=4, dim_head=32) -> None:
+        super().__init__()
+        self.scale = dim_head**-0.5
+        self.heads = heads
+        hidden_dim = dim_head * heads
+        self.to_kv = nn.Conv2d(dim, hidden_dim * 2, 1, bias=False)
+        self.to_q = nn.Conv2d(dim, hidden_dim, 1, bias=False)
+        self.out = nn.Conv2d(hidden_dim, dim, 1)
+
+    def forward(self, x, cross_attend):
+        b, c, h, w = x.shape
+        q = self.to_q(x)
+        k, v = self.to_kv(cross_attend).chunk(2, dim=1)
+        q = rearrange(q, "b (h c) x y -> b h c (x y)", h=self.heads)
+        k = rearrange(k, "b (h c) x y -> b h c (x y)", h=self.heads)
+        v = rearrange(v, "b (h c) x y -> b h c (x y)", h=self.heads)
+        q = q * self.scale
+        sim = einsum("b h d i, b h d j -> b h i j", q, k)
+        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
+        attn = sim.softmax(dim=-1)
+        out = einsum("b h i j, b h d j -> b h i d", attn, v)
+        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
+        return self.out(out)
+
+
+class LinearAttention(nn.Module):
+    def __init__(self, dim, heads=4, dim_head=32):
+        super().__init__()
+        self.scale = dim_head**-0.5
+        self.heads = heads
+        hidden_dim = dim_head * heads
+        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
+        self.to_q = nn.Conv2d(dim, hidden_dim, 1, bias=False)
+        self.to_k = nn.Conv2d(dim, hidden_dim, 1, bias=False)
+        self.to_v = nn.Conv2d(dim, hidden_dim, 1, bias=False)
+        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), nn.GroupNorm(1, dim))
+
+    def forward(self, x):
+        b, c, h, w = x.shape
+        qkv = self.to_qkv(x).chunk(3, dim=1)
+        q, k, v = map(
+            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
+        )
+        # linear attention: softmax q over the feature dimension (softmax of q^T along its last dim)
+        q = q.softmax(dim=-2)
+        # softmax k over the spatial dimension
+        k = k.softmax(dim=-1)
+        # scale the queries
+        q = q * self.scale
+        # aggregate k and v first to form a (feature x feature) context
+        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)
+        # project the context back through q
+        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
+        # restore the (b, c, h, w) layout expected by the conv output head
+        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
+        return self.to_out(out)
+
+
+class PreNorm(nn.Module):
+    def __init__(self, dim, fn):
+        super().__init__()
+        self.fn = fn
+        self.norm = nn.GroupNorm(1, dim)
+
+    def forward(self, x, *args, **kwargs):
+        x = self.norm(x)
+        return self.fn(x, *args, **kwargs)
models/structure/Advanced_Network_Helpers_3.py ADDED
@@ -0,0 +1,232 @@
+import math
+from inspect import isfunction
+from functools import partial
+import matplotlib.pyplot as plt
+from tqdm.auto import tqdm
+from einops import rearrange
+import torch
+from torch import nn, einsum
+import torch.nn.functional as F
+
+
+def exists(x):
+    return x is not None
+
+
+def default(val, d):
+    if exists(val):
+        return val
+    return d() if isfunction(d) else d
+
+
+class Residual(nn.Module):
+    def __init__(self, fn):
+        super().__init__()
+        self.fn = fn
+
+    def forward(self, x, *args, **kwargs):
+        return self.fn(x, *args, **kwargs) + x
+
+
+def Upsample(dim):
+    return nn.ConvTranspose2d(dim, dim, 4, 2, 1)
+
+
+def Downsample(dim):
+    return nn.Conv2d(dim, dim, 4, 2, 1)
+
+
+class SinusoidalPositionEmbeddings(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, time):
+        device = time.device
+        half_dim = self.dim // 2
+        embeddings = math.log(10000) / (half_dim - 1)
+        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
+        embeddings = time[:, None] * embeddings[None, :]
+        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
+        return embeddings
+
+
+class Block(nn.Module):
+    def __init__(self, dim, dim_out, groups=8):
+        super().__init__()
+        self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)
+        self.norm = nn.GroupNorm(groups, dim_out)
+        self.act = nn.SiLU()
+
+    def forward(self, x, scale_shift=None):
+        x = self.proj(x)
+        x = self.norm(x)
+
+        if exists(scale_shift):
+            scale, shift = scale_shift
+            x = x * (scale + 1) + shift
+
+        x = self.act(x)
+        return x
+
+
+class ResnetBlock(nn.Module):
+    """https://arxiv.org/abs/1512.03385"""
+
+    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
+        super().__init__()
+        self.mlp = (
+            nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out))
+            if exists(time_emb_dim)
+            else None
+        )
+
+        self.block1 = Block(dim, dim_out, groups=groups)
+        self.block2 = Block(dim_out, dim_out, groups=groups)
+        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
+
+    def forward(self, x, time_emb=None):
+        h = self.block1(x)
+
+        if exists(self.mlp) and exists(time_emb):
+            time_emb = self.mlp(time_emb)
+            h = rearrange(time_emb, "b c -> b c 1 1") + h
+
+        h = self.block2(h)
+        return h + self.res_conv(x)
+
+
+class ConvNextBlock(nn.Module):
+    """https://arxiv.org/abs/2201.03545"""
+
+    def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
+        super().__init__()
+        self.mlp = (
+            nn.Sequential(nn.GELU(), nn.Linear(time_emb_dim, dim))
+            if exists(time_emb_dim)
+            else None
+        )
+
+        self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)
+
+        self.net = nn.Sequential(
+            nn.GroupNorm(1, dim) if norm else nn.Identity(),
+            nn.Conv2d(dim, dim_out * mult, 3, padding=1),
+            nn.GELU(),
+            nn.GroupNorm(1, dim_out * mult),
+            nn.Conv2d(dim_out * mult, dim_out, 3, padding=1),
+        )
+
+        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
+
+    def forward(self, x, time_emb=None):
+        h = self.ds_conv(x)
+
+        if exists(self.mlp) and exists(time_emb):
+            assert exists(time_emb), "time embedding must be passed in"
+            condition = self.mlp(time_emb)
+            h = h + rearrange(condition, "b c -> b c 1 1")
+
+        h = self.net(h)
+        return h + self.res_conv(x)
+
+
+class Attention(nn.Module):
+    def __init__(self, dim, heads=4, dim_head=32):
+        super().__init__()
+        self.scale = dim_head**-0.5
+        self.heads = heads
+        hidden_dim = dim_head * heads
+        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
+        self.to_q = nn.Conv2d(dim, hidden_dim, 1, bias=False)
+        self.to_k = nn.Conv2d(dim, hidden_dim, 1, bias=False)
+        self.to_v = nn.Conv2d(dim, hidden_dim, 1, bias=False)
+        self.to_out = nn.Conv2d(hidden_dim, dim, 1)
+
+    def forward(self, x):
+        b, c, h, w = x.shape
+
+        qkv = self.to_qkv(x).chunk(3, dim=1)
+        q, k, v = map(
+            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
+        )
+        q = q * self.scale
+
+        sim = einsum("b h d i, b h d j -> b h i j", q, k)
+        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
+        attn = sim.softmax(dim=-1)
+
+        out = einsum("b h i j, b h d j -> b h i d", attn, v)
+        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
+
+        return self.to_out(out)
+
+
+class LinearCrossAttention(nn.Module):
+    def __init__(self, dim, heads=4, dim_head=32) -> None:
+        super().__init__()
+        self.scale = dim_head**-0.5
+        self.heads = heads
+        hidden_dim = dim_head * heads
+        self.to_kv = nn.Conv2d(dim, hidden_dim * 2, 1, bias=False)
+        self.to_q = nn.Conv2d(dim, hidden_dim, 1, bias=False)
+        self.out = nn.Conv2d(hidden_dim, dim, 1)
+
+    def forward(self, x, cross_attend):
+        b, c, h, w = x.shape
+        q = self.to_q(x)
+        k, v = self.to_kv(cross_attend).chunk(2, dim=1)
+        q = rearrange(q, "b (h c) x y -> b h c (x y)", h=self.heads)
+        k = rearrange(k, "b (h c) x y -> b h c (x y)", h=self.heads)
+        v = rearrange(v, "b (h c) x y -> b h c (x y)", h=self.heads)
+        q = q * self.scale
+        sim = einsum("b h d i, b h d j -> b h i j", q, k)
+        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
+        attn = sim.softmax(dim=-1)
+        out = einsum("b h i j, b h d j -> b h i d", attn, v)
+        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
+        return self.out(out)
+
+
+class LinearAttention(nn.Module):
+    def __init__(self, dim, heads=4, dim_head=32):
+        super().__init__()
+        self.scale = dim_head**-0.5
+        self.heads = heads
+        hidden_dim = dim_head * heads
+        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
+        self.to_q = nn.Conv2d(dim, hidden_dim, 1, bias=False)
+        self.to_k = nn.Conv2d(dim, hidden_dim, 1, bias=False)
+        self.to_v = nn.Conv2d(dim, hidden_dim, 1, bias=False)
+        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), nn.GroupNorm(1, dim))
+
+    def forward(self, x):
+        b, c, h, w = x.shape
+        qkv = self.to_qkv(x).chunk(3, dim=1)
+        q, k, v = map(
+            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
+        )
+        # linear attention: softmax q over the feature dimension (softmax of q^T along its last dim)
+        q = q.softmax(dim=-2)
+        # softmax k over the spatial dimension
+        k = k.softmax(dim=-1)
+        # scale the queries
+        q = q * self.scale
+        # aggregate k and v first to form a (feature x feature) context
+        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)
+        # project the context back through q
+        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
+        # restore the (b, c, h, w) layout expected by the conv output head
+        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
+        return self.to_out(out)
+
+
+class PreNorm(nn.Module):
+    def __init__(self, dim, fn):
+        super().__init__()
+        self.fn = fn
+        self.norm = nn.GroupNorm(1, dim)
+
+    def forward(self, x, *args, **kwargs):
+        x = self.norm(x)
+        return self.fn(x, *args, **kwargs)
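
SinusoidalPositionEmbeddings maps a batch of scalar diffusion timesteps to dim-wide vectors, half sine and half cosine features. For instance:

    import torch
    from models.structure.Advanced_Network_Helpers_3 import SinusoidalPositionEmbeddings

    emb = SinusoidalPositionEmbeddings(dim=128)
    t = torch.tensor([0, 10, 3999])  # diffusion timesteps
    print(emb(t).shape)              # torch.Size([3, 128])
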
models/structure/Unet.py ADDED
@@ -0,0 +1,152 @@
+import math
+from inspect import isfunction
+from functools import partial
+import matplotlib.pyplot as plt
+from tqdm.auto import tqdm
+from einops import rearrange
+import torch
+from torch import nn, einsum
+import torch.nn.functional as F
+from .Advanced_Network_Helpers import *
+
+
+class Unet(nn.Module):
+    def __init__(
+        self,
+        dim,
+        init_dim=None,
+        out_dim=None,
+        dim_mults=(1, 2, 4, 8),
+        channels=3,
+        with_time_emb=True,
+        resnet_block_groups=8,
+        use_convnext=True,
+        convnext_mult=2,
+    ):
+        super().__init__()
+
+        # determine dimensions
+        self.channels = channels  # since we are concatenating the images and the conditionings along the channel dimension
+
+        init_dim = default(init_dim, dim // 3 * 2)
+        self.init_conv = nn.Conv2d(self.channels * 2, init_dim, 7, padding=3)
+        self.conditioning_init = nn.Conv2d(self.channels * 2, init_dim, 7, padding=3)
+        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
+        in_out = list(zip(dims[:-1], dims[1:]))
+        self.in_out = in_out
+
+        if use_convnext:
+            block_klass = partial(ConvNextBlock, mult=convnext_mult)
+        else:
+            block_klass = partial(ResnetBlock, groups=resnet_block_groups)
+
+        # time embeddings
+        if with_time_emb:
+            time_dim = dim * 4
+            self.time_mlp = nn.Sequential(
+                SinusoidalPositionEmbeddings(dim),
+                nn.Linear(dim, time_dim),
+                nn.GELU(),
+                nn.Linear(time_dim, time_dim),
+            )
+        else:
+            time_dim = None
+            self.time_mlp = None
+
+        # layers
+        self.downs = nn.ModuleList([])
+        self.ups = nn.ModuleList([])
+        self.conditioning_encoder = nn.ModuleList([])
+        num_resolutions = len(in_out)
+        self.num_resolutions = num_resolutions
+
+        # conditioning encoder
+        for ind, (dim_in, dim_out) in enumerate(in_out):
+            is_last = ind >= (num_resolutions - 1)
+
+            self.conditioning_encoder.append(
+                nn.ModuleList(
+                    [
+                        block_klass(dim_in, dim_out),
+                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
+                        Downsample(dim_out) if not is_last else nn.Identity(),
+                    ]
+                )
+            )
+
+        for ind, (dim_in, dim_out) in enumerate(in_out):
+            is_last = ind >= (num_resolutions - 1)
+
+            self.downs.append(
+                nn.ModuleList(
+                    [
+                        block_klass(dim_in, dim_out, time_emb_dim=time_dim),
+                        block_klass(dim_out, dim_out, time_emb_dim=time_dim),
+                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
+                        Downsample(dim_out) if not is_last else nn.Identity(),
+                    ]
+                )
+            )
+
+        mid_dim = dims[-1]
+
+        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
+        self.cross_attention = Residual(PreNorm(mid_dim, LinearCrossAttention(mid_dim)))
+        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
+
+        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
+            is_last = ind >= (num_resolutions - 1)
+            self.ups.append(
+                nn.ModuleList(
+                    [
+                        block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
+                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
+                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
+                        Upsample(dim_in) if not is_last else nn.Identity(),
+                    ]
+                )
+            )
+
+        out_dim = default(out_dim, channels)
+        self.final_conv = nn.Sequential(
+            block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1)
+        )
+
+    def forward(self, x, time, implicit_conditioning, explicit_conditioning):
+        x = torch.cat((x, explicit_conditioning), dim=1)
+        conditioning = torch.cat((implicit_conditioning, explicit_conditioning), dim=1)
+        x = self.init_conv(x)
+
+        conditioning = self.conditioning_init(conditioning)
+
+        t = self.time_mlp(time) if exists(self.time_mlp) else None
+
+        h = []
+
+        # conditioning encoder
+
+        for block1, attn, downsample in self.conditioning_encoder:
+            conditioning = block1(conditioning)
+            conditioning = attn(conditioning)
+            conditioning = downsample(conditioning)
+
+        for block1, block2, attn, downsample in self.downs:
+            x = block1(x, t)
+            x = block2(x, t)
+            x = attn(x)
+            h.append(x)
+            x = downsample(x)
+
+        # bottleneck
+        x = self.mid_block1(x, t)
+        x = self.cross_attention(x, conditioning)
+        x = self.mid_block2(x, t)
+
+        for block1, block2, attn, upsample in self.ups:
+            x = torch.cat((x, h.pop()), dim=1)
+            x = block1(x, t)
+            x = block2(x, t)
+            x = attn(x)
+            x = upsample(x)
+
+        return self.final_conv(x)
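
The forward pass concatenates the noisy image with the explicit conditioning (hence the channels * 2 input convolutions) and injects the encoded conditioning at the bottleneck via cross-attention. A minimal forward-pass sketch with random tensors (sizes mirror the app's 128x128 setting):

    import torch
    from models.structure.Unet import Unet

    model = Unet(dim=128, channels=3, dim_mults=(1, 2, 4, 8), use_convnext=False)

    x = torch.randn(1, 3, 128, 128)          # noisy image
    sketch = torch.randn(1, 3, 128, 128)     # explicit conditioning
    scribbles = torch.randn(1, 3, 128, 128)  # implicit conditioning
    t = torch.tensor([10])                   # one timestep per batch element

    noise_pred = model(x, t, implicit_conditioning=scribbles, explicit_conditioning=sketch)
    print(noise_pred.shape)  # torch.Size([1, 3, 128, 128])
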
models/structure/Unet_2.py ADDED
@@ -0,0 +1,152 @@
+import math
+from inspect import isfunction
+from functools import partial
+import matplotlib.pyplot as plt
+from tqdm.auto import tqdm
+from einops import rearrange
+import torch
+from torch import nn, einsum
+import torch.nn.functional as F
+from .Advanced_Network_Helpers_2 import *
+
+
+class Unet(nn.Module):
+    def __init__(
+        self,
+        dim,
+        init_dim=None,
+        out_dim=None,
+        dim_mults=(1, 2, 4, 8),
+        channels=3,
+        with_time_emb=True,
+        resnet_block_groups=8,
+        use_convnext=True,
+        convnext_mult=2,
+    ):
+        super().__init__()
+
+        # determine dimensions
+        self.channels = channels  # since we are concatenating the images and the conditionings along the channel dimension
+
+        init_dim = default(init_dim, dim // 3 * 2)
+        self.init_conv = nn.Conv2d(self.channels * 2, init_dim, 7, padding=3)
+        self.conditioning_init = nn.Conv2d(self.channels * 2, init_dim, 7, padding=3)
+        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
+        in_out = list(zip(dims[:-1], dims[1:]))
+        self.in_out = in_out
+
+        if use_convnext:
+            block_klass = partial(ConvNextBlock, mult=convnext_mult)
+        else:
+            block_klass = partial(ResnetBlock, groups=resnet_block_groups)
+
+        # time embeddings
+        if with_time_emb:
+            time_dim = dim * 4
+            self.time_mlp = nn.Sequential(
+                SinusoidalPositionEmbeddings(dim),
+                nn.Linear(dim, time_dim),
+                nn.GELU(),
+                nn.Linear(time_dim, time_dim),
+            )
+        else:
+            time_dim = None
+            self.time_mlp = None
+
+        # layers
+        self.downs = nn.ModuleList([])
+        self.ups = nn.ModuleList([])
+        self.conditioning_encoder = nn.ModuleList([])
+        num_resolutions = len(in_out)
+        self.num_resolutions = num_resolutions
+
+        # conditioning encoder
+        for ind, (dim_in, dim_out) in enumerate(in_out):
+            is_last = ind >= (num_resolutions - 1)
+
+            self.conditioning_encoder.append(
+                nn.ModuleList(
+                    [
+                        block_klass(dim_in, dim_out),
+                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
+                        Downsample(dim_out) if not is_last else nn.Identity(),
+                    ]
+                )
+            )
+
+        for ind, (dim_in, dim_out) in enumerate(in_out):
+            is_last = ind >= (num_resolutions - 1)
+
+            self.downs.append(
+                nn.ModuleList(
+                    [
+                        block_klass(dim_in, dim_out, time_emb_dim=time_dim),
+                        block_klass(dim_out, dim_out, time_emb_dim=time_dim),
+                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
+                        Downsample(dim_out) if not is_last else nn.Identity(),
+                    ]
+                )
+            )
+
+        mid_dim = dims[-1]
+
+        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
+        self.cross_attention = Residual(PreNorm(mid_dim, LinearCrossAttention(mid_dim)))
+        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
+
+        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
+            is_last = ind >= (num_resolutions - 1)
+            self.ups.append(
+                nn.ModuleList(
+                    [
+                        block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
+                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
+                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
+                        Upsample(dim_in) if not is_last else nn.Identity(),
+                    ]
+                )
+            )
+
+        out_dim = default(out_dim, channels)
+        self.final_conv = nn.Sequential(
+            block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1)
+        )
+
+    def forward(self, x, time, implicit_conditioning, explicit_conditioning):
+        x = torch.cat((x, explicit_conditioning), dim=1)
+        conditioning = torch.cat((implicit_conditioning, explicit_conditioning), dim=1)
+        x = self.init_conv(x)
+
+        conditioning = self.conditioning_init(conditioning)
+
+        t = self.time_mlp(time) if exists(self.time_mlp) else None
+
+        h = []
+
+        # conditioning encoder
+
+        for block1, attn, downsample in self.conditioning_encoder:
+            conditioning = block1(conditioning)
+            conditioning = attn(conditioning)
+            conditioning = downsample(conditioning)
+
+        for block1, block2, attn, downsample in self.downs:
+            x = block1(x, t)
+            x = block2(x, t)
+            x = attn(x)
+            h.append(x)
+            x = downsample(x)
+
+        # bottleneck
+        x = self.mid_block1(x, t)
+        x = self.cross_attention(x, conditioning)
+        x = self.mid_block2(x, t)
+
+        for block1, block2, attn, upsample in self.ups:
+            x = torch.cat((x, h.pop()), dim=1)
+            x = block1(x, t)
+            x = block2(x, t)
+            x = attn(x)
+            x = upsample(x)
+
+        return self.final_conv(x)
models/structure/Unet_3.py ADDED
@@ -0,0 +1,166 @@
+import math
+from inspect import isfunction
+from functools import partial
+import matplotlib.pyplot as plt
+from tqdm.auto import tqdm
+from einops import rearrange
+import torch
+from torch import nn, einsum
+import torch.nn.functional as F
+from .Advanced_Network_Helpers_3 import *
+from transformers import PreTrainedModel
+
+
+class Unet(nn.Module):
+    def __init__(
+        self,
+        dim,
+        init_dim=None,
+        out_dim=None,
+        dim_mults=(1, 2, 4, 8),
+        channels=3,
+        with_time_emb=True,
+        resnet_block_groups=8,
+        use_convnext=True,
+        convnext_mult=2,
+    ):
+        super().__init__()
+
+        # determine dimensions
+        self.channels = channels  # since we are concatenating the images and the conditionings along the channel dimension
+
+        init_dim = default(init_dim, dim // 3 * 2)
+        self.init_conv = nn.Conv2d(self.channels * 2, init_dim, 7, padding=3)
+        self.conditioning_init = nn.Conv2d(self.channels, init_dim, 7, padding=3)
+        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
+        in_out = list(zip(dims[:-1], dims[1:]))
+        self.in_out = in_out
+
+        if use_convnext:
+            block_klass = partial(ConvNextBlock, mult=convnext_mult)
+        else:
+            block_klass = partial(ResnetBlock, groups=resnet_block_groups)
+
+        # time embeddings
+        if with_time_emb:
+            time_dim = dim * 4
+            self.time_mlp = nn.Sequential(
+                SinusoidalPositionEmbeddings(dim),
+                nn.Linear(dim, time_dim),
+                nn.GELU(),
+                nn.Linear(time_dim, time_dim),
+            )
+        else:
+            time_dim = None
+            self.time_mlp = None
+
+        # layers
+        self.downs = nn.ModuleList([])
+        self.ups = nn.ModuleList([])
+        self.conditioning_encoder = nn.ModuleList([])
+        num_resolutions = len(in_out)
+        self.num_resolutions = num_resolutions
+
+        # conditioning encoder
+        for ind, (dim_in, dim_out) in enumerate(in_out):
+            is_last = ind >= (num_resolutions - 1)
+
+            self.conditioning_encoder.append(
+                nn.ModuleList(
+                    [
+                        block_klass(dim_in, dim_out),
+                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
+                        Downsample(dim_out) if not is_last else nn.Identity(),
+                    ]
+                )
+            )
+
+        for ind, (dim_in, dim_out) in enumerate(in_out):
+            is_last = ind >= (num_resolutions - 1)
+
+            self.downs.append(
+                nn.ModuleList(
+                    [
+                        block_klass(dim_in, dim_out, time_emb_dim=time_dim),
+                        block_klass(dim_out, dim_out, time_emb_dim=time_dim),
+                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
+                        Downsample(dim_out) if not is_last else nn.Identity(),
+                    ]
+                )
+            )
+
+        mid_dim = dims[-1]
+
+        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
+        self.cross_attention_1 = Residual(
+            PreNorm(mid_dim, LinearCrossAttention(mid_dim))
+        )
+        self.cross_attention_2 = Residual(
+            PreNorm(mid_dim, LinearCrossAttention(mid_dim))
+        )
+        self.cross_attention_3 = Residual(
+            PreNorm(mid_dim, LinearCrossAttention(mid_dim))
+        )
+        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
+
+        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
+            is_last = ind >= (num_resolutions - 1)
+            self.ups.append(
+                nn.ModuleList(
+                    [
+                        block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
+                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
+                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
+                        Upsample(dim_in) if not is_last else nn.Identity(),
+                    ]
+                )
+            )
+
+        out_dim = default(out_dim, channels)
+        self.final_conv = nn.Sequential(
+            block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1)
+        )
+
+    def forward(self, x, time, implicit_conditioning, explicit_conditioning):
+        x = torch.cat((x, explicit_conditioning), dim=1)
+
+        x = self.init_conv(x)
+
+        conditioning = self.conditioning_init(implicit_conditioning)
+
+        t = self.time_mlp(time) if exists(self.time_mlp) else None
+
+        h = []
+
+        # conditioning encoder
+
+        for block1, attn, downsample in self.conditioning_encoder:
+            conditioning = block1(conditioning)
+            conditioning = attn(conditioning)
+            conditioning = downsample(conditioning)
+
+        for block1, block2, attn, downsample in self.downs:
+            x = block1(x, t)
+            x = block2(x, t)
+            x = attn(x)
+            h.append(x)
+            x = downsample(x)
+
+        # bottleneck
+
+        x = self.cross_attention_1(x, conditioning)
+        x = self.mid_block1(x, t)
+        x = self.cross_attention_2(x, conditioning)
+        x = self.mid_block2(x, t)
+        x = self.cross_attention_3(x, conditioning)
+
+        for block1, block2, attn, upsample in self.ups:
+            x = torch.cat((x, h.pop()), dim=1)
+            x = block1(x, t)
+            x = block2(x, t)
+            x = attn(x)
+            x = upsample(x)
+
+        return self.final_conv(x)
models/structure/hf_compatible_model.py ADDED
@@ -0,0 +1,192 @@
+from transformers import PretrainedConfig, PreTrainedModel
+import math
+from inspect import isfunction
+from functools import partial
+import matplotlib.pyplot as plt
+from tqdm.auto import tqdm
+from einops import rearrange
+import torch
+from torch import nn, einsum
+import torch.nn.functional as F
+from .Advanced_Network_Helpers_3 import *
+import os
+
+
+class UnetConfig(PretrainedConfig):
+    model_type = "unet"
+
+    def __init__(
+        self,
+        dim=64,
+        init_dim=None,
+        out_dim=None,
+        dim_mults=(1, 2, 4, 8),
+        channels=3,
+        with_time_emb=True,
+        resnet_block_groups=8,
+        use_convnext=True,
+        convnext_mult=2,
+        **kwargs
+    ):
+        super().__init__(**kwargs)
+        self.dim = dim
+        self.init_dim = init_dim
+        self.out_dim = out_dim
+        self.dim_mults = dim_mults
+        self.channels = channels
+        self.with_time_emb = with_time_emb
+        self.resnet_block_groups = resnet_block_groups
+        self.use_convnext = use_convnext
+        self.convnext_mult = convnext_mult
+
+
+class Unet(PreTrainedModel):
+    config_class = UnetConfig
+
+    def __init__(
+        self,
+        config,
+    ):
+        super().__init__(config)
+
+        # determine dimensions
+        self.channels = (
+            config.channels
+        )  # since we are concatenating the images and the conditionings along the channel dimension
+
+        init_dim = default(config.init_dim, config.dim // 3 * 2)
+        self.init_conv = nn.Conv2d(self.channels * 2, init_dim, 7, padding=3)
+        self.conditioning_init = nn.Conv2d(self.channels, init_dim, 7, padding=3)
+        dims = [init_dim, *map(lambda m: config.dim * m, config.dim_mults)]
+        in_out = list(zip(dims[:-1], dims[1:]))
+        self.in_out = in_out
+
+        if config.use_convnext:
+            block_klass = partial(ConvNextBlock, mult=config.convnext_mult)
+        else:
+            block_klass = partial(ResnetBlock, groups=config.resnet_block_groups)
+
+        # time embeddings
+        if config.with_time_emb:
+            time_dim = config.dim * 4
+            self.time_mlp = nn.Sequential(
+                SinusoidalPositionEmbeddings(config.dim),
+                nn.Linear(config.dim, time_dim),
+                nn.GELU(),
+                nn.Linear(time_dim, time_dim),
+            )
+        else:
+            time_dim = None
+            self.time_mlp = None
+
+        # layers
+        self.downs = nn.ModuleList([])
+        self.ups = nn.ModuleList([])
+        self.conditioning_encoder = nn.ModuleList([])
+        num_resolutions = len(in_out)
+        self.num_resolutions = num_resolutions
+
+        # conditioning encoder
+        for ind, (dim_in, dim_out) in enumerate(in_out):
+            is_last = ind >= (num_resolutions - 1)
+
+            self.conditioning_encoder.append(
+                nn.ModuleList(
+                    [
+                        block_klass(dim_in, dim_out),
+                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
+                        Downsample(dim_out) if not is_last else nn.Identity(),
+                    ]
+                )
+            )
+
+        for ind, (dim_in, dim_out) in enumerate(in_out):
+            is_last = ind >= (num_resolutions - 1)
+
+            self.downs.append(
+                nn.ModuleList(
+                    [
+                        block_klass(dim_in, dim_out, time_emb_dim=time_dim),
+                        block_klass(dim_out, dim_out, time_emb_dim=time_dim),
+                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
+                        Downsample(dim_out) if not is_last else nn.Identity(),
+                    ]
+                )
+            )
+
+        mid_dim = dims[-1]
+
+        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
+        self.cross_attention_1 = Residual(
+            PreNorm(mid_dim, LinearCrossAttention(mid_dim))
+        )
+        self.cross_attention_2 = Residual(
+            PreNorm(mid_dim, LinearCrossAttention(mid_dim))
+        )
+        self.cross_attention_3 = Residual(
+            PreNorm(mid_dim, LinearCrossAttention(mid_dim))
+        )
+        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
+
+        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
+            is_last = ind >= (num_resolutions - 1)
+            self.ups.append(
+                nn.ModuleList(
+                    [
+                        block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
+                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
+                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
+                        Upsample(dim_in) if not is_last else nn.Identity(),
+                    ]
+                )
+            )
+
+        out_dim = default(config.out_dim, config.channels)
+        self.final_conv = nn.Sequential(
+            block_klass(config.dim, config.dim), nn.Conv2d(config.dim, out_dim, 1)
+        )
+
+    def forward(self, x, time, implicit_conditioning, explicit_conditioning):
+        x = torch.cat((x, explicit_conditioning), dim=1)
+
+        x = self.init_conv(x)
+
+        conditioning = self.conditioning_init(implicit_conditioning)
+
+        t = self.time_mlp(time) if exists(self.time_mlp) else None
+
+        h = []
+
+        # conditioning encoder
+
+        for block1, attn, downsample in self.conditioning_encoder:
+            conditioning = block1(conditioning)
+            conditioning = attn(conditioning)
+            conditioning = downsample(conditioning)
+
+        for block1, block2, attn, downsample in self.downs:
+            x = block1(x, t)
+            x = block2(x, t)
+            x = attn(x)
+            h.append(x)
+            x = downsample(x)
+
+        # bottleneck
+
+        x = self.cross_attention_1(x, conditioning)
+        x = self.mid_block1(x, t)
+        x = self.cross_attention_2(x, conditioning)
+        x = self.mid_block2(x, t)
+        x = self.cross_attention_3(x, conditioning)
+
+        for block1, block2, attn, upsample in self.ups:
+            x = torch.cat((x, h.pop()), dim=1)
+            x = block1(x, t)
+            x = block2(x, t)
+            x = attn(x)
+            x = upsample(x)
+
+        return self.final_conv(x)
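
Wrapping the same architecture in PretrainedConfig/PreTrainedModel buys the standard Hugging Face serialization plumbing. A minimal sketch (directory name hypothetical):

    from models.structure.hf_compatible_model import Unet, UnetConfig

    config = UnetConfig(dim=128, channels=3, use_convnext=False)
    model = Unet(config)

    model.save_pretrained("unet-checkpoint")  # writes config.json + weights
    reloaded = Unet.from_pretrained("unet-checkpoint")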