init commit
- .gitattributes +1 -0
- .gitignore +40 -0
- app.py +136 -0
- load_model.py +110 -0
- models/structure/Advanced_Network_Helpers.py +255 -0
- models/structure/Advanced_Network_Helpers_2.py +232 -0
- models/structure/Advanced_Network_Helpers_3.py +232 -0
- models/structure/Unet.py +152 -0
- models/structure/Unet_2.py +152 -0
- models/structure/Unet_3.py +166 -0
- models/structure/hf_compatible_model.py +192 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.st filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,40 @@
#.model.pth

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]

# vim swp files
*.swp
# caffe/pytorch model files
*.pth
*.pt
# json
*.json
*.bin
*.st

.models/model-epoch_80.st
.history/

dataset/
wandb/
.vscode/
https://github.com/higumax/sketchKeras-pytorch.git

.startup.sh
startup.sh
app.py
ADDED
@@ -0,0 +1,136 @@
import gradio as gr
from PIL import Image
import numpy as np
from torchvision import transforms
from load_model import sample
import torch
import random

device = "cuda" if torch.cuda.is_available() else "cpu"
device = "mps" if torch.backends.mps.is_available() else device

image_size = 128
upscale = False
clicked = False


transform = transforms.Compose(
    [
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Lambda(lambda t: (t * 2) - 1),
    ]
)


def make_scribbles(sketch, scribbles):
    # replace the grey background of the scribble canvas with the sketch
    sketch = transforms.Resize((image_size, image_size))(sketch)
    scribbles = transforms.Resize((image_size, image_size))(scribbles)

    # 0.49803922 == 127/255, the grey background value of the scribble canvas
    grey_tensor = torch.tensor(0.49803922, device=device)

    grey_tensor = grey_tensor.expand(3, image_size, image_size)

    sketch = transforms.ToTensor()(sketch).to(device)
    scribbles = transforms.ToTensor()(scribbles).to(device)

    scribble_where_grey_mask = torch.eq(scribbles, grey_tensor)

    merged = torch.where(scribble_where_grey_mask, sketch, scribbles)

    return transforms.Lambda(lambda t: (t * 2) - 1)(sketch), transforms.Lambda(
        lambda t: (t * 2) - 1
    )(merged)


def process_images(sketch, scribbles, sampling_steps, is_scribbles, seed_nr, upscale):
    global clicked
    clicked = True
    w, h = sketch.size

    if is_scribbles:
        sketch, scribbles = make_scribbles(sketch, scribbles)

    else:
        sketch = transform(sketch.convert("RGB"))
        scribbles = transform(scribbles.convert("RGB"))

    if upscale:
        output = transforms.Resize((h, w))(
            sample(sketch, scribbles, sampling_steps, seed_nr)
        )
        clicked = False
        return output
    else:
        output = sample(sketch, scribbles, sampling_steps, seed_nr)
        clicked = False
        return output


theme = gr.themes.Monochrome()


with gr.Blocks(theme=theme) as demo:
    with gr.Row():
        gr.Markdown(
            "<h1 style='text-align: center; font-size: 30px;'>Image Inpainting with Conditional Diffusion by MedicAI</h1>"
        )

    with gr.Row():
        with gr.Column():
            sketch_input = gr.Image(type="pil", label="Sketch", height=500)
        with gr.Column():
            scribbles_input = gr.Image(type="pil", label="Scribbles", height=500)
            info = gr.Markdown(
                "<p style='text-align: center; font-size: 12px;'>"
                "By default the scribbles are assumed to be merged with the sketch; if they appear on a grey background, check the box below. "
                "</p>"
            )
            is_scribbles = gr.Checkbox(label="Is Scribbles", value=False)
        with gr.Column():
            output = gr.Image(type="pil", label="Output")
            upscale_info = gr.Markdown(
                "<p style='text-align: center; font-size: 12px;'>"
                f"If you want to stretch the downloadable output, check the box below; the default output of the network is {image_size}x{image_size}. "
                "</p>"
            )
            upscale_button = gr.Checkbox(label="Stretch", value=False)
    with gr.Row():
        with gr.Column():
            seed_slider = gr.Number(
                label="Random Seed 🎲",
                value=random.randint(
                    1,
                    1000,
                ),
            )

        with gr.Column():
            sampling_slider = gr.Slider(
                minimum=1,
                maximum=250,
                step=1,
                label="DDPM Sampling Steps 🔄",
                value=50,
            )

    with gr.Row():
        generate_button = gr.Button(value="Generate", interactive=not clicked)

    generate_button.click(
        process_images,
        inputs=[
            sketch_input,
            scribbles_input,
            sampling_slider,
            is_scribbles,
            seed_slider,
            upscale_button,
        ],
        outputs=output,
        show_progress=True,
    )


demo.launch(server_port=3000, max_threads=1)
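A minimal smoke-test sketch for process_images outside the Gradio UI. The file names are placeholders, and it assumes the definitions above are already loaded and a trained checkpoint exists under models/ (required by the load_model import):

# Hypothetical smoke test (placeholder file names).
from PIL import Image

sketch = Image.open("example_sketch.png")        # line-art input
scribbles = Image.open("example_scribbles.png")  # colour hints, pre-merged here
result = process_images(
    sketch,
    scribbles,
    sampling_steps=50,   # the slider default
    is_scribbles=False,  # scribbles are already merged with the sketch
    seed_nr=42,
    upscale=False,       # keep the native 128x128 output
)
result.save("output.png")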
load_model.py
ADDED
@@ -0,0 +1,110 @@
from models.structure.Unet_3 import Unet
from diffusers import DDPMScheduler
import torch
import os
import glob
from tqdm import tqdm
from torchvision import transforms
import pathlib
from torchvision.utils import save_image
from safetensors.torch import load_model, save_model


denoising_timesteps = 4000
image_size = 128
channels = 3


device = "cuda" if torch.cuda.is_available() else "cpu"
device = "mps" if torch.backends.mps.is_available() else device

model = Unet(
    dim=image_size,
    channels=channels,
    dim_mults=(1, 2, 4, 8),
    use_convnext=False,
).to(device)

results_folder = pathlib.Path("models")


checkpoint_files_st = glob.glob(str(results_folder / "model-epoch_*.st"))
checkpoint_files_pt = glob.glob(str(results_folder / "model-epoch_*.pt"))

if checkpoint_files_st:
    # Sort the matching files by modification time (newest first)
    checkpoint_files_st.sort(key=lambda x: os.path.getmtime(x), reverse=True)
    # Select the newest file and load its weights
    checkpoint_files = checkpoint_files_st[0]
    load_model(model, checkpoint_files)
    model.eval()
    print("Loaded model from checkpoint", checkpoint_files)

elif checkpoint_files_pt:
    # Sort the matching files by modification time (newest first)
    checkpoint_files_pt.sort(key=lambda x: os.path.getmtime(x), reverse=True)
    # Select the newest file and load its weights
    checkpoint_files = checkpoint_files_pt[0]
    checkpoint = torch.load(checkpoint_files, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    epoch = checkpoint["epoch"]
    model.eval()
    print("Loaded model from checkpoint", checkpoint_files)

    # Re-export the .pt checkpoint as a safetensor if none exists yet.
    # (pathlib.Path.exists() does not expand glob patterns, so check with glob.)
    if not glob.glob(str(results_folder / "model-epoch_*.st")):
        save_model(model, results_folder / "model-epoch_{}.st".format(epoch))
        print("Saved model as a safetensor", results_folder)

else:
    raise Exception("No model files found in the folder.")


def sample(sketch, scribbles, sampling_steps, seed_nr):
    torch.manual_seed(seed_nr)

    noise_scheduler = DDPMScheduler(
        num_train_timesteps=denoising_timesteps, beta_schedule="squaredcos_cap_v2"
    )
    noise_scheduler.set_timesteps(sampling_steps, device=device)

    sketch = sketch.to(device)
    scribbles = scribbles.to(device)

    sketch = sketch.unsqueeze(0)
    scribbles = scribbles.unsqueeze(0)

    with torch.no_grad():
        b = sketch.shape[0]

        noise_for_plain = torch.randn_like(sketch, device=device)

        for i, t in tqdm(
            enumerate(noise_scheduler.timesteps),
            total=len(noise_scheduler.timesteps),
        ):
            noise_for_plain = noise_scheduler.scale_model_input(noise_for_plain, t).to(
                device
            )

            time = t.expand(
                b,
            ).to(device)

            plain_noise_pred = model(
                x=noise_for_plain,
                time=time,
                implicit_conditioning=scribbles,
                explicit_conditioning=sketch,
            )

            noise_for_plain = noise_scheduler.step(
                plain_noise_pred,
                t.long(),
                noise_for_plain,
            ).prev_sample

    sample = torch.clamp((noise_for_plain / 2) + 0.5, 0, 1)

    return transforms.ToPILImage()(sample[0].cpu())
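sample expects single-image tensors already normalised to the [-1, 1] range used in training. A minimal calling sketch, with placeholder paths and assuming a checkpoint was found above:

# Minimal sketch of calling sample() directly (placeholder paths).
from PIL import Image
from torchvision import transforms

to_model_space = transforms.Compose(
    [
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),                     # [0, 1]
        transforms.Lambda(lambda t: (t * 2) - 1),  # -> [-1, 1]
    ]
)

sketch = to_model_space(Image.open("sketch.png").convert("RGB"))
scribbles = to_model_space(Image.open("scribbles.png").convert("RGB"))
out = sample(sketch, scribbles, sampling_steps=50, seed_nr=0)  # returns a PIL image
out.save("sampled.png")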
models/structure/Advanced_Network_Helpers.py
ADDED
@@ -0,0 +1,255 @@
import math
from inspect import isfunction
from functools import partial
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from einops import rearrange
import torch
from torch import nn, einsum
import torch.nn.functional as F


def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x


def Upsample(dim):
    return nn.ConvTranspose2d(dim, dim, 4, 2, 1)


def Downsample(dim):
    return nn.Conv2d(dim, dim, 4, 2, 1)


class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings


class Block(nn.Module):
    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift=None):
        x = self.proj(x)
        x = self.norm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.act(x)
        return x


class ResnetBlock(nn.Module):
    """https://arxiv.org/abs/1512.03385"""

    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
        super().__init__()
        self.mlp = (
            nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out))
            if exists(time_emb_dim)
            else None
        )

        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        h = self.block1(x)

        if exists(self.mlp) and exists(time_emb):
            time_emb = self.mlp(time_emb)
            h = rearrange(time_emb, "b c -> b c 1 1") + h

        h = self.block2(h)
        return h + self.res_conv(x)


class ConvNextBlock(nn.Module):
    """https://arxiv.org/abs/2201.03545"""

    def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
        super().__init__()
        self.mlp = (
            nn.Sequential(nn.GELU(), nn.Linear(time_emb_dim, dim))
            if exists(time_emb_dim)
            else None
        )

        self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)

        self.net = nn.Sequential(
            nn.GroupNorm(1, dim) if norm else nn.Identity(),
            nn.Conv2d(dim, dim_out * mult, 3, padding=1),
            nn.GELU(),
            nn.GroupNorm(1, dim_out * mult),
            nn.Conv2d(dim_out * mult, dim_out, 3, padding=1),
        )

        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        h = self.ds_conv(x)

        if exists(self.mlp) and exists(time_emb):
            assert exists(time_emb), "time embedding must be passed in"
            condition = self.mlp(time_emb)
            h = h + rearrange(condition, "b c -> b c 1 1")

        h = self.net(h)
        return h + self.res_conv(x)


class Attention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_q = nn.Conv2d(dim, hidden_dim, 1, bias=False)
        self.to_k = nn.Conv2d(dim, hidden_dim, 1, bias=False)
        self.to_v = nn.Conv2d(dim, hidden_dim, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x, cross_attend=None):
        b, c, h, w = x.shape

        if cross_attend is not None:
            assert cross_attend.shape == x.shape

            q_att = self.to_q(x)
            k_att = self.to_k(cross_attend)
            v_att = self.to_v(cross_attend)
            q = rearrange(q_att, "b (h c) x y -> b h c (x y)", h=self.heads)
            k = rearrange(k_att, "b (h c) x y -> b h c (x y)", h=self.heads)
            v = rearrange(v_att, "b (h c) x y -> b h c (x y)", h=self.heads)
        else:
            qkv = self.to_qkv(x).chunk(3, dim=1)
            q, k, v = map(
                lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
            )
        q = q * self.scale

        sim = einsum("b h d i, b h d j -> b h i j", q, k)
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        out = einsum("b h i j, b h d j -> b h i d", attn, v)
        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)

        return self.to_out(out)


class LinearCrossAttention(nn.Module):
    def __init__(self, dim, heads=12, dim_head=128) -> None:
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_kv = nn.Conv2d(dim, hidden_dim * 2, 1, bias=False)
        self.to_q = nn.Conv2d(dim, hidden_dim, 1, bias=False)
        self.out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x, cross_attend):
        b, c, h, w = x.shape
        q = self.to_q(x)
        k, v = self.to_kv(cross_attend).chunk(2, dim=1)
        q = rearrange(q, "b (h c) x y -> b h c (x y)", h=self.heads)
        k = rearrange(k, "b (h c) x y -> b h c (x y)", h=self.heads)
        v = rearrange(v, "b (h c) x y -> b h c (x y)", h=self.heads)
        q = q * self.scale
        sim = einsum("b h d i, b h d j -> b h i j", q, k)
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)
        out = einsum("b h i j, b h d j -> b h i d", attn, v)
        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
        return self.out(out)


class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_q = nn.Conv2d(dim, hidden_dim, 1, bias=False)
        self.to_k = nn.Conv2d(dim, hidden_dim, 1, bias=False)
        self.to_v = nn.Conv2d(dim, hidden_dim, 1, bias=False)
        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), nn.GroupNorm(1, dim))

    def forward(self, x, cross_attend=None):
        b, c, h, w = x.shape
        if cross_attend is not None:
            assert (
                cross_attend.shape == x.shape
            ), f"cross_attend must have the same shape as x: got {cross_attend.shape} and {x.shape}"

            q_att = self.to_q(x)
            k_att = self.to_k(cross_attend)
            v_att = self.to_v(cross_attend)
            q = rearrange(q_att, "b (h c) x y -> b h c (x y)", h=self.heads)
            k = rearrange(k_att, "b (h c) x y -> b h c (x y)", h=self.heads)
            v = rearrange(v_att, "b (h c) x y -> b h c (x y)", h=self.heads)

        else:
            qkv = self.to_qkv(x).chunk(3, dim=1)
            q, k, v = map(
                lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
            )
        # softmax over the feature dimension of q (equivalent to a softmax over the last dim of q^T)
        q = q.softmax(dim=-2)
        # softmax over the pixel (sequence) dimension of k
        k = k.softmax(dim=-1)
        # scale the queries
        q = q * self.scale
        # dot product of k and v over the pixel dimension
        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)
        # project the context back onto the queries
        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        # rearrange the output back to the PyTorch image layout
        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
        return self.to_out(out)


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.GroupNorm(1, dim)

    def forward(self, x, *args, **kwargs):
        x = self.norm(x)
        return self.fn(x, *args, **kwargs)
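LinearAttention builds a dim_head x dim_head context matrix instead of an n x n attention map, so memory scales with the feature size rather than the number of pixels. A shape-check sketch of both paths (self-contained apart from the class above):

# Shape check: both attention paths map (b, dim, h, w) -> (b, dim, h, w);
# the cross path draws keys/values from a conditioning map of equal shape.
import torch

attn = LinearAttention(dim=64, heads=4, dim_head=32)
x = torch.randn(2, 64, 16, 16)
cond = torch.randn(2, 64, 16, 16)

assert attn(x).shape == x.shape        # self-attention
assert attn(x, cond).shape == x.shape  # cross-attention on the conditioning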
models/structure/Advanced_Network_Helpers_2.py
ADDED
@@ -0,0 +1,232 @@
import math
from inspect import isfunction
from functools import partial
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from einops import rearrange
import torch
from torch import nn, einsum
import torch.nn.functional as F


def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x


def Upsample(dim):
    return nn.ConvTranspose2d(dim, dim, 4, 2, 1)


def Downsample(dim):
    return nn.Conv2d(dim, dim, 4, 2, 1)


class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings


class Block(nn.Module):
    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift=None):
        x = self.proj(x)
        x = self.norm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.act(x)
        return x


class ResnetBlock(nn.Module):
    """https://arxiv.org/abs/1512.03385"""

    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
        super().__init__()
        self.mlp = (
            nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out))
            if exists(time_emb_dim)
            else None
        )

        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        h = self.block1(x)

        if exists(self.mlp) and exists(time_emb):
            time_emb = self.mlp(time_emb)
            h = rearrange(time_emb, "b c -> b c 1 1") + h

        h = self.block2(h)
        return h + self.res_conv(x)


class ConvNextBlock(nn.Module):
    """https://arxiv.org/abs/2201.03545"""

    def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
        super().__init__()
        self.mlp = (
            nn.Sequential(nn.GELU(), nn.Linear(time_emb_dim, dim))
            if exists(time_emb_dim)
            else None
        )

        self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)

        self.net = nn.Sequential(
            nn.GroupNorm(1, dim) if norm else nn.Identity(),
            nn.Conv2d(dim, dim_out * mult, 3, padding=1),
            nn.GELU(),
            nn.GroupNorm(1, dim_out * mult),
            nn.Conv2d(dim_out * mult, dim_out, 3, padding=1),
        )

        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        h = self.ds_conv(x)

        if exists(self.mlp) and exists(time_emb):
            assert exists(time_emb), "time embedding must be passed in"
            condition = self.mlp(time_emb)
            h = h + rearrange(condition, "b c -> b c 1 1")

        h = self.net(h)
        return h + self.res_conv(x)


class Attention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_q = nn.Conv2d(dim, hidden_dim, 1, bias=False)
        self.to_k = nn.Conv2d(dim, hidden_dim, 1, bias=False)
        self.to_v = nn.Conv2d(dim, hidden_dim, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape

        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )
        q = q * self.scale

        sim = einsum("b h d i, b h d j -> b h i j", q, k)
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        out = einsum("b h i j, b h d j -> b h i d", attn, v)
        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)

        return self.to_out(out)


class LinearCrossAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32) -> None:
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_kv = nn.Conv2d(dim, hidden_dim * 2, 1, bias=False)
        self.to_q = nn.Conv2d(dim, hidden_dim, 1, bias=False)
        self.out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x, cross_attend):
        b, c, h, w = x.shape
        q = self.to_q(x)
        k, v = self.to_kv(cross_attend).chunk(2, dim=1)
        q = rearrange(q, "b (h c) x y -> b h c (x y)", h=self.heads)
        k = rearrange(k, "b (h c) x y -> b h c (x y)", h=self.heads)
        v = rearrange(v, "b (h c) x y -> b h c (x y)", h=self.heads)
        q = q * self.scale
        sim = einsum("b h d i, b h d j -> b h i j", q, k)
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)
        out = einsum("b h i j, b h d j -> b h i d", attn, v)
        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
        return self.out(out)


class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_q = nn.Conv2d(dim, hidden_dim, 1, bias=False)
        self.to_k = nn.Conv2d(dim, hidden_dim, 1, bias=False)
        self.to_v = nn.Conv2d(dim, hidden_dim, 1, bias=False)
        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), nn.GroupNorm(1, dim))

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )
        # softmax over the feature dimension of q (equivalent to a softmax over the last dim of q^T)
        q = q.softmax(dim=-2)
        # softmax over the pixel (sequence) dimension of k
        k = k.softmax(dim=-1)
        # scale the queries
        q = q * self.scale
        # dot product of k and v over the pixel dimension
        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)
        # project the context back onto the queries
        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        # rearrange the output back to the PyTorch image layout
        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
        return self.to_out(out)


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.GroupNorm(1, dim)

    def forward(self, x, *args, **kwargs):
        x = self.norm(x)
        return self.fn(x, *args, **kwargs)
models/structure/Advanced_Network_Helpers_3.py
ADDED
@@ -0,0 +1,232 @@
import math
from inspect import isfunction
from functools import partial
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from einops import rearrange
import torch
from torch import nn, einsum
import torch.nn.functional as F


def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x


def Upsample(dim):
    return nn.ConvTranspose2d(dim, dim, 4, 2, 1)


def Downsample(dim):
    return nn.Conv2d(dim, dim, 4, 2, 1)


class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings


class Block(nn.Module):
    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift=None):
        x = self.proj(x)
        x = self.norm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.act(x)
        return x


class ResnetBlock(nn.Module):
    """https://arxiv.org/abs/1512.03385"""

    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
        super().__init__()
        self.mlp = (
            nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out))
            if exists(time_emb_dim)
            else None
        )

        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        h = self.block1(x)

        if exists(self.mlp) and exists(time_emb):
            time_emb = self.mlp(time_emb)
            h = rearrange(time_emb, "b c -> b c 1 1") + h

        h = self.block2(h)
        return h + self.res_conv(x)


class ConvNextBlock(nn.Module):
    """https://arxiv.org/abs/2201.03545"""

    def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
        super().__init__()
        self.mlp = (
            nn.Sequential(nn.GELU(), nn.Linear(time_emb_dim, dim))
            if exists(time_emb_dim)
            else None
        )

        self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)

        self.net = nn.Sequential(
            nn.GroupNorm(1, dim) if norm else nn.Identity(),
            nn.Conv2d(dim, dim_out * mult, 3, padding=1),
            nn.GELU(),
            nn.GroupNorm(1, dim_out * mult),
            nn.Conv2d(dim_out * mult, dim_out, 3, padding=1),
        )

        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        h = self.ds_conv(x)

        if exists(self.mlp) and exists(time_emb):
            assert exists(time_emb), "time embedding must be passed in"
            condition = self.mlp(time_emb)
            h = h + rearrange(condition, "b c -> b c 1 1")

        h = self.net(h)
        return h + self.res_conv(x)


class Attention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_q = nn.Conv2d(dim, hidden_dim, 1, bias=False)
        self.to_k = nn.Conv2d(dim, hidden_dim, 1, bias=False)
        self.to_v = nn.Conv2d(dim, hidden_dim, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape

        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )
        q = q * self.scale

        sim = einsum("b h d i, b h d j -> b h i j", q, k)
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        out = einsum("b h i j, b h d j -> b h i d", attn, v)
        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)

        return self.to_out(out)


class LinearCrossAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32) -> None:
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_kv = nn.Conv2d(dim, hidden_dim * 2, 1, bias=False)
        self.to_q = nn.Conv2d(dim, hidden_dim, 1, bias=False)
        self.out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x, cross_attend):
        b, c, h, w = x.shape
        q = self.to_q(x)
        k, v = self.to_kv(cross_attend).chunk(2, dim=1)
        q = rearrange(q, "b (h c) x y -> b h c (x y)", h=self.heads)
        k = rearrange(k, "b (h c) x y -> b h c (x y)", h=self.heads)
        v = rearrange(v, "b (h c) x y -> b h c (x y)", h=self.heads)
        q = q * self.scale
        sim = einsum("b h d i, b h d j -> b h i j", q, k)
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)
        out = einsum("b h i j, b h d j -> b h i d", attn, v)
        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
        return self.out(out)


class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_q = nn.Conv2d(dim, hidden_dim, 1, bias=False)
        self.to_k = nn.Conv2d(dim, hidden_dim, 1, bias=False)
        self.to_v = nn.Conv2d(dim, hidden_dim, 1, bias=False)
        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), nn.GroupNorm(1, dim))

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )
        # softmax over the feature dimension of q (equivalent to a softmax over the last dim of q^T)
        q = q.softmax(dim=-2)
        # softmax over the pixel (sequence) dimension of k
        k = k.softmax(dim=-1)
        # scale the queries
        q = q * self.scale
        # dot product of k and v over the pixel dimension
        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)
        # project the context back onto the queries
        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        # rearrange the output back to the PyTorch image layout
        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
        return self.to_out(out)


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.GroupNorm(1, dim)

    def forward(self, x, *args, **kwargs):
        x = self.norm(x)
        return self.fn(x, *args, **kwargs)
models/structure/Unet.py
ADDED
@@ -0,0 +1,152 @@
import math
from inspect import isfunction
from functools import partial
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from einops import rearrange
import torch
from torch import nn, einsum
import torch.nn.functional as F
from .Advanced_Network_Helpers import *


class Unet(nn.Module):
    def __init__(
        self,
        dim,
        init_dim=None,
        out_dim=None,
        dim_mults=(1, 2, 4, 8),
        channels=3,
        with_time_emb=True,
        resnet_block_groups=8,
        use_convnext=True,
        convnext_mult=2,
    ):
        super().__init__()

        # determine dimensions
        self.channels = channels  # since we are concatenating the images and the conditionings along the channel dimension

        init_dim = default(init_dim, dim // 3 * 2)
        self.init_conv = nn.Conv2d(self.channels * 2, init_dim, 7, padding=3)
        self.conditioning_init = nn.Conv2d(self.channels * 2, init_dim, 7, padding=3)
        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        self.in_out = in_out

        if use_convnext:
            block_klass = partial(ConvNextBlock, mult=convnext_mult)
        else:
            block_klass = partial(ResnetBlock, groups=resnet_block_groups)

        # time embeddings
        if with_time_emb:
            time_dim = dim * 4
            self.time_mlp = nn.Sequential(
                SinusoidalPositionEmbeddings(dim),
                nn.Linear(dim, time_dim),
                nn.GELU(),
                nn.Linear(time_dim, time_dim),
            )
        else:
            time_dim = None
            self.time_mlp = None

        # layers
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        self.conditioning_encoder = nn.ModuleList([])
        num_resolutions = len(in_out)
        self.num_resolutions = num_resolutions

        # conditioning encoder
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.conditioning_encoder.append(
                nn.ModuleList(
                    [
                        block_klass(dim_in, dim_out),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                        Downsample(dim_out) if not is_last else nn.Identity(),
                    ]
                )
            )

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(
                nn.ModuleList(
                    [
                        block_klass(dim_in, dim_out, time_emb_dim=time_dim),
                        block_klass(dim_out, dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                        Downsample(dim_out) if not is_last else nn.Identity(),
                    ]
                )
            )

        mid_dim = dims[-1]

        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.cross_attention = Residual(PreNorm(mid_dim, LinearCrossAttention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)
            self.ups.append(
                nn.ModuleList(
                    [
                        block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                        Upsample(dim_in) if not is_last else nn.Identity(),
                    ]
                )
            )

        out_dim = default(out_dim, channels)
        self.final_conv = nn.Sequential(
            block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1)
        )

    def forward(self, x, time, implicit_conditioning, explicit_conditioning):
        x = torch.cat((x, explicit_conditioning), dim=1)
        conditioning = torch.cat((implicit_conditioning, explicit_conditioning), dim=1)
        x = self.init_conv(x)

        conditioning = self.conditioning_init(conditioning)

        t = self.time_mlp(time) if exists(self.time_mlp) else None

        h = []

        # conditioning encoder

        for block1, attn, downsample in self.conditioning_encoder:
            conditioning = block1(conditioning)
            conditioning = attn(conditioning)
            conditioning = downsample(conditioning)

        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            h.append(x)
            x = downsample(x)

        # bottleneck
        x = self.mid_block1(x, t)
        x = self.cross_attention(x, conditioning)
        x = self.mid_block2(x, t)

        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            x = upsample(x)

        return self.final_conv(x)
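A forward-pass sketch mirroring the configuration used in load_model.py; with dim=128 and three-channel inputs, the network returns a tensor of the input image shape:

# Smoke test: x is concatenated with the explicit conditioning along the
# channel axis, so all three image inputs share the (b, 3, 128, 128) shape.
import torch
from models.structure.Unet import Unet

net = Unet(dim=128, channels=3, dim_mults=(1, 2, 4, 8), use_convnext=False)

b = 1
x = torch.randn(b, 3, 128, 128)          # noised image
sketch = torch.randn(b, 3, 128, 128)     # explicit conditioning
scribbles = torch.randn(b, 3, 128, 128)  # implicit conditioning
t = torch.randint(0, 4000, (b,))         # diffusion timestep per batch item

out = net(x, t, implicit_conditioning=scribbles, explicit_conditioning=sketch)
assert out.shape == (b, 3, 128, 128)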
models/structure/Unet_2.py
ADDED
@@ -0,0 +1,152 @@
import math
from inspect import isfunction
from functools import partial
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from einops import rearrange
import torch
from torch import nn, einsum
import torch.nn.functional as F
from .Advanced_Network_Helpers_2 import *


class Unet(nn.Module):
    def __init__(
        self,
        dim,
        init_dim=None,
        out_dim=None,
        dim_mults=(1, 2, 4, 8),
        channels=3,
        with_time_emb=True,
        resnet_block_groups=8,
        use_convnext=True,
        convnext_mult=2,
    ):
        super().__init__()

        # determine dimensions
        self.channels = channels  # since we are concatenating the images and the conditionings along the channel dimension

        init_dim = default(init_dim, dim // 3 * 2)
        self.init_conv = nn.Conv2d(self.channels * 2, init_dim, 7, padding=3)
        self.conditioning_init = nn.Conv2d(self.channels * 2, init_dim, 7, padding=3)
        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        self.in_out = in_out

        if use_convnext:
            block_klass = partial(ConvNextBlock, mult=convnext_mult)
        else:
            block_klass = partial(ResnetBlock, groups=resnet_block_groups)

        # time embeddings
        if with_time_emb:
            time_dim = dim * 4
            self.time_mlp = nn.Sequential(
                SinusoidalPositionEmbeddings(dim),
                nn.Linear(dim, time_dim),
                nn.GELU(),
                nn.Linear(time_dim, time_dim),
            )
        else:
            time_dim = None
            self.time_mlp = None

        # layers
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        self.conditioning_encoder = nn.ModuleList([])
        num_resolutions = len(in_out)
        self.num_resolutions = num_resolutions

        # conditioning encoder
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.conditioning_encoder.append(
                nn.ModuleList(
                    [
                        block_klass(dim_in, dim_out),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                        Downsample(dim_out) if not is_last else nn.Identity(),
                    ]
                )
            )

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(
                nn.ModuleList(
                    [
                        block_klass(dim_in, dim_out, time_emb_dim=time_dim),
                        block_klass(dim_out, dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                        Downsample(dim_out) if not is_last else nn.Identity(),
                    ]
                )
            )

        mid_dim = dims[-1]

        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.cross_attention = Residual(PreNorm(mid_dim, LinearCrossAttention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)
            self.ups.append(
                nn.ModuleList(
                    [
                        block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                        Upsample(dim_in) if not is_last else nn.Identity(),
                    ]
                )
            )

        out_dim = default(out_dim, channels)
        self.final_conv = nn.Sequential(
            block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1)
        )

    def forward(self, x, time, implicit_conditioning, explicit_conditioning):
        x = torch.cat((x, explicit_conditioning), dim=1)
        conditioning = torch.cat((implicit_conditioning, explicit_conditioning), dim=1)
        x = self.init_conv(x)

        conditioning = self.conditioning_init(conditioning)

        t = self.time_mlp(time) if exists(self.time_mlp) else None

        h = []

        # conditioning encoder

        for block1, attn, downsample in self.conditioning_encoder:
            conditioning = block1(conditioning)
            conditioning = attn(conditioning)
            conditioning = downsample(conditioning)

        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            h.append(x)
            x = downsample(x)

        # bottleneck
        x = self.mid_block1(x, t)
        x = self.cross_attention(x, conditioning)
        x = self.mid_block2(x, t)

        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            x = upsample(x)

        return self.final_conv(x)
models/structure/Unet_3.py
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import math
from inspect import isfunction
from functools import partial
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from einops import rearrange
import torch
from torch import nn, einsum
import torch.nn.functional as F
from .Advanced_Network_Helpers_3 import *
from transformers import PreTrainedModel


class Unet(nn.Module):
    def __init__(
        self,
        dim,
        init_dim=None,
        out_dim=None,
        dim_mults=(1, 2, 4, 8),
        channels=3,
        with_time_emb=True,
        resnet_block_groups=8,
        use_convnext=True,
        convnext_mult=2,
    ):
        super().__init__()

        # determine dimensions
        # the noisy image and the explicit conditioning are concatenated
        # along the channel dimension, so init_conv takes channels * 2 inputs
        self.channels = channels

        init_dim = default(init_dim, dim // 3 * 2)
        self.init_conv = nn.Conv2d(self.channels * 2, init_dim, 7, padding=3)
        self.conditioning_init = nn.Conv2d(self.channels, init_dim, 7, padding=3)
        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        self.in_out = in_out

        if use_convnext:
            block_klass = partial(ConvNextBlock, mult=convnext_mult)
        else:
            block_klass = partial(ResnetBlock, groups=resnet_block_groups)

        # time embeddings
        if with_time_emb:
            time_dim = dim * 4
            self.time_mlp = nn.Sequential(
                SinusoidalPositionEmbeddings(dim),
                nn.Linear(dim, time_dim),
                nn.GELU(),
                nn.Linear(time_dim, time_dim),
            )
        else:
            time_dim = None
            self.time_mlp = None

        # layers
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        self.conditioning_encoder = nn.ModuleList([])
        num_resolutions = len(in_out)
        self.num_resolutions = num_resolutions

        # conditioning encoder: mirrors the downsampling path, without time embeddings
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.conditioning_encoder.append(
                nn.ModuleList(
                    [
                        block_klass(dim_in, dim_out),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                        Downsample(dim_out) if not is_last else nn.Identity(),
                    ]
                )
            )

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(
                nn.ModuleList(
                    [
                        block_klass(dim_in, dim_out, time_emb_dim=time_dim),
                        block_klass(dim_out, dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                        Downsample(dim_out) if not is_last else nn.Identity(),
                    ]
                )
            )

        mid_dim = dims[-1]

        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.cross_attention_1 = Residual(
            PreNorm(mid_dim, LinearCrossAttention(mid_dim))
        )
        self.cross_attention_2 = Residual(
            PreNorm(mid_dim, LinearCrossAttention(mid_dim))
        )
        self.cross_attention_3 = Residual(
            PreNorm(mid_dim, LinearCrossAttention(mid_dim))
        )
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)
            self.ups.append(
                nn.ModuleList(
                    [
                        block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                        Upsample(dim_in) if not is_last else nn.Identity(),
                    ]
                )
            )

        out_dim = default(out_dim, channels)
        self.final_conv = nn.Sequential(
            block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1)
        )

    def forward(self, x, time, implicit_conditioning, explicit_conditioning):
        # explicit conditioning is concatenated channel-wise with the noisy input
        x = torch.cat((x, explicit_conditioning), dim=1)

        x = self.init_conv(x)

        conditioning = self.conditioning_init(implicit_conditioning)

        t = self.time_mlp(time) if exists(self.time_mlp) else None

        h = []

        # encode the implicit conditioning down to the bottleneck resolution
        for block1, attn, downsample in self.conditioning_encoder:
            conditioning = block1(conditioning)
            conditioning = attn(conditioning)
            conditioning = downsample(conditioning)

        # downsampling path; skip activations are stacked in h
        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            h.append(x)
            x = downsample(x)

        # bottleneck: cross-attend to the encoded conditioning around the mid blocks
        x = self.cross_attention_1(x, conditioning)
        x = self.mid_block1(x, t)
        x = self.cross_attention_2(x, conditioning)
        x = self.mid_block2(x, t)
        x = self.cross_attention_3(x, conditioning)

        # upsampling path; h is consumed last-in-first-out, so no reversal is needed
        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            x = upsample(x)

        return self.final_conv(x)
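For orientation, a minimal smoke test of this Unet's forward pass. It is a sketch under assumptions: the helper blocks in Advanced_Network_Helpers_3 follow the usual stride-2 Downsample / Upsample convention, and the batch size, 128x128 resolution, and timestep range are illustrative values, not values fixed by this repo:

import torch
from models.structure.Unet_3 import Unet

model = Unet(dim=64, channels=3)

x = torch.randn(2, 3, 128, 128)          # noisy images being denoised
sketch = torch.randn(2, 3, 128, 128)     # explicit conditioning (channel-concatenated)
scribble = torch.randn(2, 3, 128, 128)   # implicit conditioning (cross-attended)
t = torch.randint(0, 1000, (2,))         # one diffusion timestep per sample

out = model(x, t, scribble, sketch)
print(out.shape)  # expected: torch.Size([2, 3, 128, 128])

With dim_mults=(1, 2, 4, 8) the input is downsampled three times (the last stage is an Identity), so a 128x128 input reaches the bottleneck at 16x16 before being upsampled back to full resolution.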
models/structure/hf_compatible_model.py
ADDED
@@ -0,0 +1,192 @@
from transformers import PretrainedConfig, PreTrainedModel
import math
from inspect import isfunction
from functools import partial
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from einops import rearrange
import torch
from torch import nn, einsum
import torch.nn.functional as F
from .Advanced_Network_Helpers_3 import *
import os


class UnetConfig(PretrainedConfig):
    model_type = "unet"

    def __init__(
        self,
        dim=64,
        init_dim=None,
        out_dim=None,
        dim_mults=(1, 2, 4, 8),
        channels=3,
        with_time_emb=True,
        resnet_block_groups=8,
        use_convnext=True,
        convnext_mult=2,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.dim = dim
        self.init_dim = init_dim
        self.out_dim = out_dim
        self.dim_mults = dim_mults
        self.channels = channels
        self.with_time_emb = with_time_emb
        self.resnet_block_groups = resnet_block_groups
        self.use_convnext = use_convnext
        self.convnext_mult = convnext_mult


class Unet(PreTrainedModel):
    config_class = UnetConfig

    def __init__(
        self,
        config,
    ):
        super().__init__(config)

        # determine dimensions
        # the noisy image and the explicit conditioning are concatenated
        # along the channel dimension, so init_conv takes channels * 2 inputs
        self.channels = config.channels

        init_dim = default(config.init_dim, config.dim // 3 * 2)
        self.init_conv = nn.Conv2d(self.channels * 2, init_dim, 7, padding=3)
        self.conditioning_init = nn.Conv2d(self.channels, init_dim, 7, padding=3)
        dims = [init_dim, *map(lambda m: config.dim * m, config.dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        self.in_out = in_out

        if config.use_convnext:
            block_klass = partial(ConvNextBlock, mult=config.convnext_mult)
        else:
            block_klass = partial(ResnetBlock, groups=config.resnet_block_groups)

        # time embeddings
        if config.with_time_emb:
            time_dim = config.dim * 4
            self.time_mlp = nn.Sequential(
                SinusoidalPositionEmbeddings(config.dim),
                nn.Linear(config.dim, time_dim),
                nn.GELU(),
                nn.Linear(time_dim, time_dim),
            )
        else:
            time_dim = None
            self.time_mlp = None

        # layers
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        self.conditioning_encoder = nn.ModuleList([])
        num_resolutions = len(in_out)
        self.num_resolutions = num_resolutions

        # conditioning encoder: mirrors the downsampling path, without time embeddings
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.conditioning_encoder.append(
                nn.ModuleList(
                    [
                        block_klass(dim_in, dim_out),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                        Downsample(dim_out) if not is_last else nn.Identity(),
                    ]
                )
            )

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(
                nn.ModuleList(
                    [
                        block_klass(dim_in, dim_out, time_emb_dim=time_dim),
                        block_klass(dim_out, dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                        Downsample(dim_out) if not is_last else nn.Identity(),
                    ]
                )
            )

        mid_dim = dims[-1]

        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.cross_attention_1 = Residual(
            PreNorm(mid_dim, LinearCrossAttention(mid_dim))
        )
        self.cross_attention_2 = Residual(
            PreNorm(mid_dim, LinearCrossAttention(mid_dim))
        )
        self.cross_attention_3 = Residual(
            PreNorm(mid_dim, LinearCrossAttention(mid_dim))
        )
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)
            self.ups.append(
                nn.ModuleList(
                    [
                        block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                        Upsample(dim_in) if not is_last else nn.Identity(),
                    ]
                )
            )

        out_dim = default(config.out_dim, config.channels)
        self.final_conv = nn.Sequential(
            block_klass(config.dim, config.dim), nn.Conv2d(config.dim, out_dim, 1)
        )

    def forward(self, x, time, implicit_conditioning, explicit_conditioning):
        # explicit conditioning is concatenated channel-wise with the noisy input
        x = torch.cat((x, explicit_conditioning), dim=1)

        x = self.init_conv(x)

        conditioning = self.conditioning_init(implicit_conditioning)

        t = self.time_mlp(time) if exists(self.time_mlp) else None

        h = []

        # encode the implicit conditioning down to the bottleneck resolution
        for block1, attn, downsample in self.conditioning_encoder:
            conditioning = block1(conditioning)
            conditioning = attn(conditioning)
            conditioning = downsample(conditioning)

        # downsampling path; skip activations are stacked in h
        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            h.append(x)
            x = downsample(x)

        # bottleneck: cross-attend to the encoded conditioning around the mid blocks
        x = self.cross_attention_1(x, conditioning)
        x = self.mid_block1(x, t)
        x = self.cross_attention_2(x, conditioning)
        x = self.mid_block2(x, t)
        x = self.cross_attention_3(x, conditioning)

        # upsampling path; h is consumed last-in-first-out, so no reversal is needed
        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            x = upsample(x)

        return self.final_conv(x)
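Because this wrapper subclasses PreTrainedModel with config_class = UnetConfig, checkpoints can be written and restored through the standard save_pretrained / from_pretrained API. A minimal sketch, assuming the package is importable from the repo root; the "./unet-checkpoint" path is illustrative, and note that PretrainedConfig round-trips through JSON, so tuple fields such as dim_mults come back as lists:

import torch
from models.structure.hf_compatible_model import Unet, UnetConfig

config = UnetConfig(dim=64, channels=3)
model = Unet(config)

model.save_pretrained("./unet-checkpoint")            # writes config.json + weights
restored = Unet.from_pretrained("./unet-checkpoint")  # rebuilds from the saved config

x = torch.randn(1, 3, 128, 128)
t = torch.randint(0, 1000, (1,))
cond = torch.randn(1, 3, 128, 128)
with torch.no_grad():
    out = restored(x, t, cond, cond)  # implicit and explicit conditioning
print(out.shape)  # expected: torch.Size([1, 3, 128, 128])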