fixed code readability

Files changed:
- app.py +1 -3
- load_model.py +1 -3
- models/structure/{Unet_3.py → Advanced_Conditional_Unet.py} +2 -8
- models/structure/Advanced_Network_Helpers.py +11 -34
- models/structure/Advanced_Network_Helpers_2.py +0 -232
- models/structure/Advanced_Network_Helpers_3.py +0 -232
- models/structure/Unet.py +0 -152
- models/structure/Unet_2.py +0 -152
- models/structure/hf_compatible_model.py +0 -192
- requirements.txt +3 -7
- results/sample.png +0 -0
app.py
CHANGED
@@ -1,13 +1,11 @@
 import gradio as gr
 from PIL import Image
-import numpy as np
 from torchvision import transforms
 from load_model import sample
 import torch
-import glob
 import random
 import os
-
+
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 device = "mps" if torch.backends.mps.is_available() else device
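
For orientation, a minimal sketch (not code from this repo) of how app.py's pieces typically wire together — the device selection above plus a Gradio interface around load_model.sample. The generate wrapper and the assumption that sample maps a PIL image to a PIL image are mine:

import gradio as gr
import torch
from load_model import sample  # the repo's sampling entry point

device = "cuda" if torch.cuda.is_available() else "cpu"
device = "mps" if torch.backends.mps.is_available() else device

def generate(image):
    # Hypothetical wrapper: the real call signature of `sample` lives in load_model.py.
    return sample(image)

demo = gr.Interface(fn=generate, inputs=gr.Image(type="pil"), outputs="image")
demo.launch()
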
load_model.py
CHANGED
@@ -1,12 +1,10 @@
-from models.structure.Unet_3 import Unet
+from models.structure.Advanced_Conditional_Unet import Unet
 from diffusers import DDPMScheduler
 import torch
 import os
 import glob
-from tqdm import tqdm
 from torchvision import transforms
 import pathlib
-from torchvision.utils import save_image
 from safetensors.torch import load_model, save_model
 import time as tm
 
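
As a rough sketch of what this module sets up — the renamed Unet, a DDPM noise scheduler, and safetensors weight loading. The checkpoint filename, dim value, and timestep count below are assumptions, not values taken from the repo:

import torch
from safetensors.torch import load_model
from diffusers import DDPMScheduler
from models.structure.Advanced_Conditional_Unet import Unet

model = Unet(dim=64)                     # dim is a guess; see the Unet definition
load_model(model, "model.safetensors")   # hypothetical checkpoint path
model.eval()

scheduler = DDPMScheduler(num_train_timesteps=1000)  # assumed schedule length
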
models/structure/{Unet_3.py → Advanced_Conditional_Unet.py}
RENAMED
@@ -1,14 +1,8 @@
-import math
-from inspect import isfunction
 from functools import partial
-import matplotlib.pyplot as plt
-from tqdm.auto import tqdm
-from einops import rearrange
 import torch
-from torch import nn
+from torch import nn
 import torch.nn.functional as F
-from .[…]
-from transformers import PreTrainedModel
+from .Advanced_Network_Helpers import *
 
 
 class Unet(nn.Module):
models/structure/Advanced_Network_Helpers.py
CHANGED
@@ -143,23 +143,13 @@ class Attention(nn.Module):
         self.to_v = nn.Conv2d(dim, hidden_dim, 1, bias=False)
         self.to_out = nn.Conv2d(hidden_dim, dim, 1)
 
-    def forward(self, x[…]
+    def forward(self, x):
         b, c, h, w = x.shape
 
-        [four removed lines lost in page extraction]
-            k_att = self.to_k(cross_attend)
-            v_att = self.to_v(cross_attend)
-            q = rearrange(q_att, "b (h c) x y -> b h c (x y)", h=self.heads)
-            k = rearrange(k_att, "b (h c) x y -> b h c (x y)", h=self.heads)
-            v = rearrange(v_att, "b (h c) x y -> b h c (x y)", h=self.heads)
-        else:
-            qkv = self.to_qkv(x).chunk(3, dim=1)
-            q, k, v = map(
-                lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
-            )
+        qkv = self.to_qkv(x).chunk(3, dim=1)
+        q, k, v = map(
+            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
+        )
         q = q * self.scale
 
         sim = einsum("b h d i, b h d j -> b h i j", q, k)
@@ -173,7 +163,7 @@ class Attention(nn.Module):
 
 
 class LinearCrossAttention(nn.Module):
-    def __init__(self, dim, heads=[…]
+    def __init__(self, dim, heads=4, dim_head=32) -> None:
         super().__init__()
         self.scale = dim_head**-0.5
         self.heads = heads
@@ -210,25 +200,12 @@ class LinearAttention(nn.Module):
         self.to_v = nn.Conv2d(dim, hidden_dim, 1, bias=False)
         self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), nn.GroupNorm(1, dim))
 
-    def forward(self, x[…]
+    def forward(self, x):
         b, c, h, w = x.shape
-        [five removed lines lost in page extraction]
-            q_att = self.to_q(x)
-            k_att = self.to_k(cross_attend)
-            v_att = self.to_v(cross_attend)
-            q = rearrange(q_att, "b (h c) x y -> b h c (x y)", h=self.heads)
-            k = rearrange(k_att, "b (h c) x y -> b h c (x y)", h=self.heads)
-            v = rearrange(v_att, "b (h c) x y -> b h c (x y)", h=self.heads)
-
-        else:
-            qkv = self.to_qkv(x).chunk(3, dim=1)
-            q, k, v = map(
-                lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
-            )
+        qkv = self.to_qkv(x).chunk(3, dim=1)
+        q, k, v = map(
+            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
+        )
         # calculate the softmax with respect to columns softmax of equivalent to q^T with respect to last dim
         q = q.softmax(dim=-2)
         # calculate the softmax with respect to rows of k
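
Net effect of this file's change: the unused cross-attention branches are stripped out of Attention and LinearAttention, leaving them as pure self-attention over a feature map (cross-attention with the conditioning remains in LinearCrossAttention). A quick smoke test of the simplified modules — the dim, heads, and input shape here are arbitrary choices of mine:

import torch
from models.structure.Advanced_Network_Helpers import Attention, LinearAttention

x = torch.randn(2, 64, 16, 16)                 # (batch, channels, height, width), shape assumed
attn = Attention(dim=64, heads=4, dim_head=32)
lin = LinearAttention(dim=64, heads=4, dim_head=32)

print(attn(x).shape)  # torch.Size([2, 64, 16, 16]) — both modules preserve the input shape
print(lin(x).shape)   # torch.Size([2, 64, 16, 16])
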
models/structure/Advanced_Network_Helpers_2.py
DELETED
@@ -1,232 +0,0 @@
-import math
-from inspect import isfunction
-from functools import partial
-import matplotlib.pyplot as plt
-from tqdm.auto import tqdm
-from einops import rearrange
-import torch
-from torch import nn, einsum
-import torch.nn.functional as F
-
-
-def exists(x):
-    return x is not None
-
-
-def default(val, d):
-    if exists(val):
-        return val
-    return d() if isfunction(d) else d
-
-
-class Residual(nn.Module):
-    def __init__(self, fn):
-        super().__init__()
-        self.fn = fn
-
-    def forward(self, x, *args, **kwargs):
-        return self.fn(x, *args, **kwargs) + x
-
-
-def Upsample(dim):
-    return nn.ConvTranspose2d(dim, dim, 4, 2, 1)
-
-
-def Downsample(dim):
-    return nn.Conv2d(dim, dim, 4, 2, 1)
-
-
-class SinusoidalPositionEmbeddings(nn.Module):
-    def __init__(self, dim):
-        super().__init__()
-        self.dim = dim
-
-    def forward(self, time):
-        device = time.device
-        half_dim = self.dim // 2
-        embeddings = math.log(10000) / (half_dim - 1)
-        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
-        embeddings = time[:, None] * embeddings[None, :]
-        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
-        return embeddings
-
-
-class Block(nn.Module):
-    def __init__(self, dim, dim_out, groups=8):
-        super().__init__()
-        self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)
-        self.norm = nn.GroupNorm(groups, dim_out)
-        self.act = nn.SiLU()
-
-    def forward(self, x, scale_shift=None):
-        x = self.proj(x)
-        x = self.norm(x)
-
-        if exists(scale_shift):
-            scale, shift = scale_shift
-            x = x * (scale + 1) + shift
-
-        x = self.act(x)
-        return x
-
-
-class ResnetBlock(nn.Module):
-    """https://arxiv.org/abs/1512.03385"""
-
-    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
-        super().__init__()
-        self.mlp = (
-            nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out))
-            if exists(time_emb_dim)
-            else None
-        )
-
-        self.block1 = Block(dim, dim_out, groups=groups)
-        self.block2 = Block(dim_out, dim_out, groups=groups)
-        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
-
-    def forward(self, x, time_emb=None):
-        h = self.block1(x)
-
-        if exists(self.mlp) and exists(time_emb):
-            time_emb = self.mlp(time_emb)
-            h = rearrange(time_emb, "b c -> b c 1 1") + h
-
-        h = self.block2(h)
-        return h + self.res_conv(x)
-
-
-class ConvNextBlock(nn.Module):
-    """https://arxiv.org/abs/2201.03545"""
-
-    def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
-        super().__init__()
-        self.mlp = (
-            nn.Sequential(nn.GELU(), nn.Linear(time_emb_dim, dim))
-            if exists(time_emb_dim)
-            else None
-        )
-
-        self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)
-
-        self.net = nn.Sequential(
-            nn.GroupNorm(1, dim) if norm else nn.Identity(),
-            nn.Conv2d(dim, dim_out * mult, 3, padding=1),
-            nn.GELU(),
-            nn.GroupNorm(1, dim_out * mult),
-            nn.Conv2d(dim_out * mult, dim_out, 3, padding=1),
-        )
-
-        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
-
-    def forward(self, x, time_emb=None):
-        h = self.ds_conv(x)
-
-        if exists(self.mlp) and exists(time_emb):
-            assert exists(time_emb), "time embedding must be passed in"
-            condition = self.mlp(time_emb)
-            h = h + rearrange(condition, "b c -> b c 1 1")
-
-        h = self.net(h)
-        return h + self.res_conv(x)
-
-
-class Attention(nn.Module):
-    def __init__(self, dim, heads=4, dim_head=32):
-        super().__init__()
-        self.scale = dim_head**-0.5
-        self.heads = heads
-        hidden_dim = dim_head * heads
-        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
-        self.to_q = nn.Conv2d(dim, hidden_dim, 1, bias=False)
-        self.to_k = nn.Conv2d(dim, hidden_dim, 1, bias=False)
-        self.to_v = nn.Conv2d(dim, hidden_dim, 1, bias=False)
-        self.to_out = nn.Conv2d(hidden_dim, dim, 1)
-
-    def forward(self, x):
-        b, c, h, w = x.shape
-
-        qkv = self.to_qkv(x).chunk(3, dim=1)
-        q, k, v = map(
-            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
-        )
-        q = q * self.scale
-
-        sim = einsum("b h d i, b h d j -> b h i j", q, k)
-        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
-        attn = sim.softmax(dim=-1)
-
-        out = einsum("b h i j, b h d j -> b h i d", attn, v)
-        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
-
-        return self.to_out(out)
-
-
-class LinearCrossAttention(nn.Module):
-    def __init__(self, dim, heads=4, dim_head=32) -> None:
-        super().__init__()
-        self.scale = dim_head**-0.5
-        self.heads = heads
-        hidden_dim = dim_head * heads
-        self.to_kv = nn.Conv2d(dim, hidden_dim * 2, 1, bias=False)
-        self.to_q = nn.Conv2d(dim, hidden_dim, 1, bias=False)
-        self.out = nn.Conv2d(hidden_dim, dim, 1)
-
-    def forward(self, x, cross_attend):
-        b, c, h, w = x.shape
-        q = self.to_q(x)
-        k, v = self.to_kv(cross_attend).chunk(2, dim=1)
-        q = rearrange(q, "b (h c) x y -> b h c (x y)", h=self.heads)
-        k = rearrange(k, "b (h c) x y -> b h c (x y)", h=self.heads)
-        v = rearrange(v, "b (h c) x y -> b h c (x y)", h=self.heads)
-        q = q * self.scale
-        sim = einsum("b h d i, b h d j -> b h i j", q, k)
-        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
-        attn = sim.softmax(dim=-1)
-        out = einsum("b h i j, b h d j -> b h i d", attn, v)
-        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
-        return self.out(out)
-
-
-class LinearAttention(nn.Module):
-    def __init__(self, dim, heads=4, dim_head=32):
-        super().__init__()
-        self.scale = dim_head**-0.5
-        self.heads = heads
-        hidden_dim = dim_head * heads
-        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
-        self.to_q = nn.Conv2d(dim, hidden_dim, 1, bias=False)
-        self.to_k = nn.Conv2d(dim, hidden_dim, 1, bias=False)
-        self.to_v = nn.Conv2d(dim, hidden_dim, 1, bias=False)
-        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), nn.GroupNorm(1, dim))
-
-    def forward(self, x):
-        b, c, h, w = x.shape
-        qkv = self.to_qkv(x).chunk(3, dim=1)
-        q, k, v = map(
-            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
-        )
-        # calculate the softmax with respect to columns softmax of equivalent to q^T with respect to last dim
-        q = q.softmax(dim=-2)
-        # calculate the softmax with respect to rows of k
-        k = k.softmax(dim=-1)
-        # normalize the values in the attention matrix
-        q = q * self.scale
-        # dot product of q and v matrices
-        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)
-        # dot product of context and q
-        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
-        # rearrange the output to match the pytorch convention
-        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
-        return self.to_out(out)
-
-
-class PreNorm(nn.Module):
-    def __init__(self, dim, fn):
-        super().__init__()
-        self.fn = fn
-        self.norm = nn.GroupNorm(1, dim)
-
-    def forward(self, x, *args, **kwargs):
-        x = self.norm(x)
-        return self.fn(x, *args, **kwargs)
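
The two deleted helper files duplicate what Advanced_Network_Helpers.py already provides. As a small illustration of one of the shared pieces, here is what SinusoidalPositionEmbeddings computes for a batch of timesteps, written out inline (the dim value and the timesteps are arbitrary examples of mine):

import math
import torch

dim = 8
half_dim = dim // 2
t = torch.tensor([0.0, 10.0, 100.0])    # three example diffusion timesteps

# Same arithmetic as SinusoidalPositionEmbeddings.forward above
freqs = torch.exp(torch.arange(half_dim) * -(math.log(10000) / (half_dim - 1)))
emb = t[:, None] * freqs[None, :]        # (3, half_dim)
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
print(emb.shape)                         # torch.Size([3, 8]) — one dim-wide vector per timestep
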
models/structure/Advanced_Network_Helpers_3.py
DELETED
@@ -1,232 +0,0 @@
(The deleted file is line-for-line identical to Advanced_Network_Helpers_2.py above; all 232 lines are removed.)
models/structure/Unet.py
DELETED
@@ -1,152 +0,0 @@
-import math
-from inspect import isfunction
-from functools import partial
-import matplotlib.pyplot as plt
-from tqdm.auto import tqdm
-from einops import rearrange
-import torch
-from torch import nn, einsum
-import torch.nn.functional as F
-from .Advanced_Network_Helpers import *
-
-
-class Unet(nn.Module):
-    def __init__(
-        self,
-        dim,
-        init_dim=None,
-        out_dim=None,
-        dim_mults=(1, 2, 4, 8),
-        channels=3,
-        with_time_emb=True,
-        resnet_block_groups=8,
-        use_convnext=True,
-        convnext_mult=2,
-    ):
-        super().__init__()
-
-        # determine dimensions
-        self.channels = channels  # since we are concatenating the images and the conditionings along the channel dimension
-
-        init_dim = default(init_dim, dim // 3 * 2)
-        self.init_conv = nn.Conv2d(self.channels * 2, init_dim, 7, padding=3)
-        self.conditioning_init = nn.Conv2d(self.channels * 2, init_dim, 7, padding=3)
-        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
-        in_out = list(zip(dims[:-1], dims[1:]))
-        self.in_out = in_out
-
-        if use_convnext:
-            block_klass = partial(ConvNextBlock, mult=convnext_mult)
-        else:
-            block_klass = partial(ResnetBlock, groups=resnet_block_groups)
-
-        # time embeddings
-        if with_time_emb:
-            time_dim = dim * 4
-            self.time_mlp = nn.Sequential(
-                SinusoidalPositionEmbeddings(dim),
-                nn.Linear(dim, time_dim),
-                nn.GELU(),
-                nn.Linear(time_dim, time_dim),
-            )
-        else:
-            time_dim = None
-            self.time_mlp = None
-
-        # layers
-        self.downs = nn.ModuleList([])
-        self.ups = nn.ModuleList([])
-        self.conditioning_encoder = nn.ModuleList([])
-        num_resolutions = len(in_out)
-        self.num_resolutions = num_resolutions
-
-        # conditioning encoder
-        for ind, (dim_in, dim_out) in enumerate(in_out):
-            is_last = ind >= (num_resolutions - 1)
-
-            self.conditioning_encoder.append(
-                nn.ModuleList(
-                    [
-                        block_klass(dim_in, dim_out),
-                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
-                        Downsample(dim_out) if not is_last else nn.Identity(),
-                    ]
-                )
-            )
-
-        for ind, (dim_in, dim_out) in enumerate(in_out):
-            is_last = ind >= (num_resolutions - 1)
-
-            self.downs.append(
-                nn.ModuleList(
-                    [
-                        block_klass(dim_in, dim_out, time_emb_dim=time_dim),
-                        block_klass(dim_out, dim_out, time_emb_dim=time_dim),
-                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
-                        Downsample(dim_out) if not is_last else nn.Identity(),
-                    ]
-                )
-            )
-
-        mid_dim = dims[-1]
-
-        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
-        self.cross_attention = Residual(PreNorm(mid_dim, LinearCrossAttention(mid_dim)))
-        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
-
-        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
-            is_last = ind >= (num_resolutions - 1)
-            self.ups.append(
-                nn.ModuleList(
-                    [
-                        block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
-                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
-                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
-                        Upsample(dim_in) if not is_last else nn.Identity(),
-                    ]
-                )
-            )
-
-        out_dim = default(out_dim, channels)
-        self.final_conv = nn.Sequential(
-            block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1)
-        )
-
-    def forward(self, x, time, implicit_conditioning, explicit_conditioning):
-        x = torch.cat((x, explicit_conditioning), dim=1)
-        conditioning = torch.cat((implicit_conditioning, explicit_conditioning), dim=1)
-        x = self.init_conv(x)
-
-        conditioning = self.conditioning_init(conditioning)
-
-        t = self.time_mlp(time) if exists(self.time_mlp) else None
-
-        h = []
-
-        # conditioning encoder
-
-        for block1, attn, downsample in self.conditioning_encoder:
-            conditioning = block1(conditioning)
-            conditioning = attn(conditioning)
-            conditioning = downsample(conditioning)
-
-        for block1, block2, attn, downsample in self.downs:
-            x = block1(x, t)
-            x = block2(x, t)
-            x = attn(x)
-            h.append(x)
-            x = downsample(x)
-
-        # bottleneck
-        x = self.mid_block1(x, t)
-        x = self.cross_attention(x, conditioning)
-        x = self.mid_block2(x, t)
-
-        for block1, block2, attn, upsample in self.ups:
-            x = torch.cat((x, h.pop()), dim=1)
-            x = block1(x, t)
-            x = block2(x, t)
-            x = attn(x)
-            x = upsample(x)
-
-        return self.final_conv(x)
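
For reference, a sketch of how this now-deleted Unet was invoked — its forward takes the noisy image, a timestep, and two conditioning images, and the first two convolutions expect channel-wise concatenations. The dim value and the 64×64 spatial size are assumptions of mine:

import torch

model = Unet(dim=64, channels=3)       # dim assumed; init convs see channels*2 after concatenation
x = torch.randn(1, 3, 64, 64)          # noisy image
implicit = torch.randn(1, 3, 64, 64)   # implicit conditioning image
explicit = torch.randn(1, 3, 64, 64)   # explicit conditioning image
t = torch.tensor([10])                 # diffusion timestep

out = model(x, t, implicit, explicit)
print(out.shape)                       # torch.Size([1, 3, 64, 64]) — same shape as the input image
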
models/structure/Unet_2.py
DELETED
@@ -1,152 +0,0 @@
(The deleted file is identical to Unet.py above except for line 10, which reads `from .Advanced_Network_Helpers_2 import *` instead of `from .Advanced_Network_Helpers import *`; all 152 lines are removed.)
models/structure/hf_compatible_model.py
DELETED
@@ -1,192 +0,0 @@
-from transformers import PretrainedConfig, PreTrainedModel
-import math
-from inspect import isfunction
-from functools import partial
-import matplotlib.pyplot as plt
-from tqdm.auto import tqdm
-from einops import rearrange
-import torch
-from torch import nn, einsum
-import torch.nn.functional as F
-from transformers import PreTrainedModel
-from .Advanced_Network_Helpers_3 import *
-import os
-
-
-class UnetConfig(PretrainedConfig):
-    model_type = "unet"
-
-    def __init__(
-        self,
-        dim=64,
-        init_dim=None,
-        out_dim=None,
-        dim_mults=(1, 2, 4, 8),
-        channels=3,
-        with_time_emb=True,
-        resnet_block_groups=8,
-        use_convnext=True,
-        convnext_mult=2,
-        **kwargs
-    ):
-        super().__init__(**kwargs)
-        self.dim = dim
-        self.init_dim = init_dim
-        self.out_dim = out_dim
-        self.dim_mults = dim_mults
-        self.channels = channels
-        self.with_time_emb = with_time_emb
-        self.resnet_block_groups = resnet_block_groups
-        self.use_convnext = use_convnext
-        self.convnext_mult = convnext_mult
-
-
-class Unet(PreTrainedModel):
-    config_class = UnetConfig
-
-    def __init__(
-        self,
-        config,
-    ):
-        super().__init__(config)
-
-        # determine dimensions
-        self.channels = (
-            config.channels
-        )  # since we are concatenating the images and the conditionings along the channel dimension
-
-        init_dim = default(config.init_dim, config.dim // 3 * 2)
-        self.init_conv = nn.Conv2d(self.channels * 2, init_dim, 7, padding=3)
-        self.conditioning_init = nn.Conv2d(self.channels, init_dim, 7, padding=3)
-        dims = [init_dim, *map(lambda m: config.dim * m, config.dim_mults)]
-        in_out = list(zip(dims[:-1], dims[1:]))
-        self.in_out = in_out
-
-        if config.use_convnext:
-            block_klass = partial(ConvNextBlock, mult=config.convnext_mult)
-        else:
-            block_klass = partial(ResnetBlock, groups=config.resnet_block_groups)
-
-        # time embeddings
-        if config.with_time_emb:
-            time_dim = config.dim * 4
-            self.time_mlp = nn.Sequential(
-                SinusoidalPositionEmbeddings(config.dim),
-                nn.Linear(config.dim, time_dim),
-                nn.GELU(),
-                nn.Linear(time_dim, time_dim),
-            )
-        else:
-            time_dim = None
-            self.time_mlp = None
-
-        # layers
-        self.downs = nn.ModuleList([])
-        self.ups = nn.ModuleList([])
-        self.conditioning_encoder = nn.ModuleList([])
-        num_resolutions = len(in_out)
-        self.num_resolutions = num_resolutions
-
-        # conditioning encoder
-        for ind, (dim_in, dim_out) in enumerate(in_out):
-            is_last = ind >= (num_resolutions - 1)
-
-            self.conditioning_encoder.append(
-                nn.ModuleList(
-                    [
-                        block_klass(dim_in, dim_out),
-                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
-                        Downsample(dim_out) if not is_last else nn.Identity(),
-                    ]
-                )
-            )
-
-        for ind, (dim_in, dim_out) in enumerate(in_out):
-            is_last = ind >= (num_resolutions - 1)
-
-            self.downs.append(
-                nn.ModuleList(
-                    [
-                        block_klass(dim_in, dim_out, time_emb_dim=time_dim),
-                        block_klass(dim_out, dim_out, time_emb_dim=time_dim),
-                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
-                        Downsample(dim_out) if not is_last else nn.Identity(),
-                    ]
-                )
-            )
-
-        mid_dim = dims[-1]
-
-        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
-        self.cross_attention_1 = Residual(
-            PreNorm(mid_dim, LinearCrossAttention(mid_dim))
-        )
-        self.cross_attention_2 = Residual(
-            PreNorm(mid_dim, LinearCrossAttention(mid_dim))
-        )
-        self.cross_attention_3 = Residual(
-            PreNorm(mid_dim, LinearCrossAttention(mid_dim))
-        )
-        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
-
-        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
-            is_last = ind >= (num_resolutions - 1)
-            self.ups.append(
-                nn.ModuleList(
-                    [
-                        block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
-                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
-                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
-                        Upsample(dim_in) if not is_last else nn.Identity(),
-                    ]
-                )
-            )
-
-        out_dim = default(config.out_dim, config.channels)
-        self.final_conv = nn.Sequential(
-            block_klass(config.dim, config.dim), nn.Conv2d(config.dim, out_dim, 1)
-        )
-
-    def forward(self, x, time, implicit_conditioning, explicit_conditioning):
-        x = torch.cat((x, explicit_conditioning), dim=1)
-
-        x = self.init_conv(x)
-
-        conditioning = self.conditioning_init(implicit_conditioning)
-
-        t = self.time_mlp(time) if exists(self.time_mlp) else None
-
-        h = []
-
-        # conditioning encoder
-
-        for block1, attn, downsample in self.conditioning_encoder:
-            conditioning = block1(conditioning)
-            conditioning = attn(conditioning)
-            conditioning = downsample(conditioning)
-
-        for block1, block2, attn, downsample in self.downs:
-            x = block1(x, t)
-            x = block2(x, t)
-            x = attn(x)
-            h.append(x)
-            x = downsample(x)
-
-        # reverse the c list
-
-        # bottleneck
-
-        x = self.cross_attention_1(x, conditioning)
-        x = self.mid_block1(x, t)
-        x = self.cross_attention_2(x, conditioning)
-        x = self.mid_block2(x, t)
-        x = self.cross_attention_3(x, conditioning)
-
-        for block1, block2, attn, upsample in self.ups:
-            x = torch.cat((x, h.pop()), dim=1)
-            x = block1(x, t)
-            x = block2(x, t)
-            x = attn(x)
-            x = upsample(x)
-
-        return self.final_conv(x)
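
This deleted wrapper followed the standard transformers pattern for custom models: a PretrainedConfig subclass carrying the hyperparameters and a PreTrainedModel subclass built from it, which is what makes save_pretrained/from_pretrained work. A hedged sketch of that usage (the directory name is hypothetical):

config = UnetConfig(dim=64, channels=3)     # hyperparameters travel in the config
model = Unet(config)                        # the PreTrainedModel subclass above

model.save_pretrained("unet-checkpoint")    # writes config.json plus the weights
reloaded = Unet.from_pretrained("unet-checkpoint")
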
requirements.txt
CHANGED
@@ -1,14 +1,10 @@
 einops
 datasets
-matplotlib
 tqdm
 accelerate
-jax[cpu]
 torchinfo
-wandb
-ema_pytorch
-lpips
-pyyaml
 diffusers
 transformers
-
+pathlib
+safetensors
+
results/sample.png
CHANGED
[binary change: the old and new sample images are shown side by side on the original page]