HF models now built with config not pickle

- app.py (+8 -4)
- msma.py (+22 -12)
- networks_edm2.py (+318 -0)
app.py
CHANGED

@@ -11,7 +11,12 @@ import torch
 from huggingface_hub import hf_hub_download
 from safetensors.torch import load_file
 
-from msma import …
+from msma import (
+    ScoreFlow,
+    build_model_from_config,
+    build_model_from_pickle,
+    config_presets,
+)
 
 
 @cache
@@ -32,7 +37,6 @@ def load_model_from_hub(preset, device):
     if 'DNNLIB_CACHE_DIR' in os.environ:
         cache_dir = os.environ["DNNLIB_CACHE_DIR"]
 
-    scorenet = build_model_from_pickle(preset)
 
     for fname in ['config.json', 'gmm.pkl', 'refscores.npz', 'model.safetensors']:
         cached_fname = hf_hub_download(
@@ -49,10 +53,10 @@ def load_model_from_hub(preset, device):
     print("Loaded:", model_params)
 
     hf_checkpoint = f"{modeldir}/model.safetensors"
-    model = …
+    model = build_model_from_config(model_params)
     model.load_state_dict(load_file(hf_checkpoint), strict=True)
     model = model.eval().requires_grad_(False)
-
+    model.to(device)
     return model, modeldir
 
 
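For context, a minimal sketch of how the reworked loader is driven. This is not part of the commit: it assumes the snippet runs inside app.py (where load_model_from_hub is defined), and the input shape is a guess based on the img64 preset name:

# Hypothetical usage of the config-based loading path (illustrative only).
# load_model_from_hub downloads config.json, gmm.pkl, refscores.npz, and
# model.safetensors from the Hub, then builds the model from config.json.
import torch

model, modeldir = load_model_from_hub(preset="edm2-img64-s-fid", device="cpu")
x = torch.rand(1, 3, 64, 64)  # assumed input shape for an img64 preset
out = model(x)                # ScoreFlow forward pass; semantics defined in msma.py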
msma.py
CHANGED

@@ -21,6 +21,7 @@ from tqdm import tqdm
 import dnnlib
 from dataset import ImageFolderDataset
 from flowutils import PatchFlow, sanitize_locals
+from networks_edm2 import Precond
 
 DEVICE: Literal["cuda", "cpu"] = "cpu"
 model_root = "https://nvlabs-fi-cdn.nvidia.com/edm2/posthoc-reconstructions"
@@ -122,6 +123,14 @@ class ScoreFlow(torch.nn.Module):
         return self.flow(x_scores)
 
 
+def build_model_from_config(model_params):
+    net = Precond(**model_params["EDMNet"])
+    scorenet = EDMScorer(net=net, **model_params["EDMScorer"])
+    scoreflow = ScoreFlow(scorenet=scorenet, **model_params["PatchFlow"])
+    print("Built model from config")
+    return scoreflow
+
+
 def build_model_from_pickle(preset="edm2-img64-s-fid", device="cpu"):
     netpath = config_presets[preset]
     with dnnlib.util.open_url(netpath, verbose=1) as f:
@@ -196,13 +205,13 @@ def cmdline():
 def common_args(func):
     @wraps(func)
     @click.option(
-        …
-        …
-        …
-        …
-        …
-        …
-    )
+        "--preset",
+        help="Configuration preset",
+        metavar="STR",
+        type=str,
+        default="edm2-img64-s-fid",
+        show_default=True,
+    )
     @click.option(
         "--dataset_path",
         help="Path to the dataset",
@@ -222,7 +231,8 @@ def common_args(func):
 
     return wrapper
 
-…
+
+@cmdline.command("train-gmm")
 @click.option(
     "--gridsearch",
     help="Whether to use a grid search on a number of components to find the best fit",
@@ -365,7 +375,7 @@ def train_flow(dataset_path, preset, outdir, epochs, batch_size, **flow_kwargs):
         train_ds, batch_size=batch_size, num_workers=4, prefetch_factor=2, shuffle=True
     )
     testiter = torch.utils.data.DataLoader(
-        val_ds, batch_size=batch_size*2, num_workers=4, prefetch_factor=2
+        val_ds, batch_size=batch_size * 2, num_workers=4, prefetch_factor=2
     )
 
     scorenet = build_model_from_pickle(preset)
@@ -392,10 +402,10 @@ def train_flow(dataset_path, preset, outdir, epochs, batch_size, **flow_kwargs):
     timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M")
     writer = SummaryWriter(f"{experiment_dir}/logs/{timestamp}")
 
-    with open(f"{experiment_dir}/logs/{timestamp}/config.json", "w") as f:
+    with open(f"{experiment_dir}/logs/{timestamp}/config.json", "w") as f:
         json.dump(model.config, f, sort_keys=True, indent=4)
 
-    with open(f"{experiment_dir}/config.json", "w") as f:
+    with open(f"{experiment_dir}/config.json", "w") as f:
         json.dump(model.config, f, sort_keys=True, indent=4)
 
     # totaliters = int(epochs * train_len)
@@ -463,7 +473,7 @@ def train_flow(dataset_path, preset, outdir, epochs, batch_size, **flow_kwargs):
 
     # Save final model
     torch.save(model.flow.state_dict(), f"{experiment_dir}/flow.pt")
-
+
     writer.close()
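To make the new loading path concrete, here is a plausible shape for the config.json that build_model_from_config consumes. The three section names ("EDMNet", "EDMScorer", "PatchFlow") come straight from the function body above; the keys inside each section are illustrative guesses, not the repo's actual configuration:

# Hypothetical model_params, as returned by json.load on config.json.
model_params = {
    "EDMNet": {          # kwargs forwarded to Precond(**...)
        "img_resolution": 64,
        "img_channels": 3,
        "label_dim": 0,
    },
    "EDMScorer": {},     # kwargs for EDMScorer (defined elsewhere in msma.py)
    "PatchFlow": {},     # kwargs for ScoreFlow
}
scoreflow = build_model_from_config(model_params)  # Precond -> EDMScorer -> ScoreFlow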
networks_edm2.py
ADDED
@@ -0,0 +1,318 @@
+# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is licensed under a Creative Commons
+# Attribution-NonCommercial-ShareAlike 4.0 International License.
+# You should have received a copy of the license along with this
+# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
+
+"""Improved diffusion model architecture proposed in the paper
+"Analyzing and Improving the Training Dynamics of Diffusion Models"."""
+
+import numpy as np
+import torch
+from torch_utils import persistence
+from torch_utils import misc
+
+#----------------------------------------------------------------------------
+# Normalize given tensor to unit magnitude with respect to the given
+# dimensions. Default = all dimensions except the first.
+
+def normalize(x, dim=None, eps=1e-4):
+    if dim is None:
+        dim = list(range(1, x.ndim))
+    norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
+    norm = torch.add(eps, norm, alpha=np.sqrt(norm.numel() / x.numel()))
+    return x / norm.to(x.dtype)
+
+#----------------------------------------------------------------------------
+# Upsample or downsample the given tensor with the given filter,
+# or keep it as is.
+
+def resample(x, f=[1,1], mode='keep'):
+    if mode == 'keep':
+        return x
+    f = np.float32(f)
+    assert f.ndim == 1 and len(f) % 2 == 0
+    pad = (len(f) - 1) // 2
+    f = f / f.sum()
+    f = np.outer(f, f)[np.newaxis, np.newaxis, :, :]
+    f = misc.const_like(x, f)
+    c = x.shape[1]
+    if mode == 'down':
+        return torch.nn.functional.conv2d(x, f.tile([c, 1, 1, 1]), groups=c, stride=2, padding=(pad,))
+    assert mode == 'up'
+    return torch.nn.functional.conv_transpose2d(x, (f * 4).tile([c, 1, 1, 1]), groups=c, stride=2, padding=(pad,))
+
+#----------------------------------------------------------------------------
+# Magnitude-preserving SiLU (Equation 81).
+
+def mp_silu(x):
+    return torch.nn.functional.silu(x) / 0.596
+
+#----------------------------------------------------------------------------
+# Magnitude-preserving sum (Equation 88).
+
+def mp_sum(a, b, t=0.5):
+    return a.lerp(b, t) / np.sqrt((1 - t) ** 2 + t ** 2)
+
+#----------------------------------------------------------------------------
+# Magnitude-preserving concatenation (Equation 103).
+
+def mp_cat(a, b, dim=1, t=0.5):
+    Na = a.shape[dim]
+    Nb = b.shape[dim]
+    C = np.sqrt((Na + Nb) / ((1 - t) ** 2 + t ** 2))
+    wa = C / np.sqrt(Na) * (1 - t)
+    wb = C / np.sqrt(Nb) * t
+    return torch.cat([wa * a, wb * b], dim=dim)
+
+#----------------------------------------------------------------------------
+# Magnitude-preserving Fourier features (Equation 75).
+
+@persistence.persistent_class
+class MPFourier(torch.nn.Module):
+    def __init__(self, num_channels, bandwidth=1):
+        super().__init__()
+        self.register_buffer('freqs', 2 * np.pi * torch.randn(num_channels) * bandwidth)
+        self.register_buffer('phases', 2 * np.pi * torch.rand(num_channels))
+
+    def forward(self, x):
+        y = x.to(torch.float32)
+        y = y.ger(self.freqs.to(torch.float32))
+        y = y + self.phases.to(torch.float32)
+        y = y.cos() * np.sqrt(2)
+        return y.to(x.dtype)
+
+#----------------------------------------------------------------------------
+# Magnitude-preserving convolution or fully-connected layer (Equation 47)
+# with force weight normalization (Equation 66).
+
+@persistence.persistent_class
+class MPConv(torch.nn.Module):
+    def __init__(self, in_channels, out_channels, kernel):
+        super().__init__()
+        self.out_channels = out_channels
+        self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, *kernel))
+
+    def forward(self, x, gain=1):
+        w = self.weight.to(torch.float32)
+        if self.training:
+            with torch.no_grad():
+                self.weight.copy_(normalize(w)) # forced weight normalization
+        w = normalize(w) # traditional weight normalization
+        w = w * (gain / np.sqrt(w[0].numel())) # magnitude-preserving scaling
+        w = w.to(x.dtype)
+        if w.ndim == 2:
+            return x @ w.t()
+        assert w.ndim == 4
+        return torch.nn.functional.conv2d(x, w, padding=(w.shape[-1]//2,))
+
+#----------------------------------------------------------------------------
+# U-Net encoder/decoder block with optional self-attention (Figure 21).
+
+@persistence.persistent_class
+class Block(torch.nn.Module):
+    def __init__(self,
+        in_channels,                    # Number of input channels.
+        out_channels,                   # Number of output channels.
+        emb_channels,                   # Number of embedding channels.
+        flavor            = 'enc',      # Flavor: 'enc' or 'dec'.
+        resample_mode     = 'keep',     # Resampling: 'keep', 'up', or 'down'.
+        resample_filter   = [1,1],      # Resampling filter.
+        attention         = False,      # Include self-attention?
+        channels_per_head = 64,         # Number of channels per attention head.
+        dropout           = 0,          # Dropout probability.
+        res_balance       = 0.3,        # Balance between main branch (0) and residual branch (1).
+        attn_balance      = 0.3,        # Balance between main branch (0) and self-attention (1).
+        clip_act          = 256,        # Clip output activations. None = do not clip.
+    ):
+        super().__init__()
+        self.out_channels = out_channels
+        self.flavor = flavor
+        self.resample_filter = resample_filter
+        self.resample_mode = resample_mode
+        self.num_heads = out_channels // channels_per_head if attention else 0
+        self.dropout = dropout
+        self.res_balance = res_balance
+        self.attn_balance = attn_balance
+        self.clip_act = clip_act
+        self.emb_gain = torch.nn.Parameter(torch.zeros([]))
+        self.conv_res0 = MPConv(out_channels if flavor == 'enc' else in_channels, out_channels, kernel=[3,3])
+        self.emb_linear = MPConv(emb_channels, out_channels, kernel=[])
+        self.conv_res1 = MPConv(out_channels, out_channels, kernel=[3,3])
+        self.conv_skip = MPConv(in_channels, out_channels, kernel=[1,1]) if in_channels != out_channels else None
+        self.attn_qkv = MPConv(out_channels, out_channels * 3, kernel=[1,1]) if self.num_heads != 0 else None
+        self.attn_proj = MPConv(out_channels, out_channels, kernel=[1,1]) if self.num_heads != 0 else None
+
+    def forward(self, x, emb):
+        # Main branch.
+        x = resample(x, f=self.resample_filter, mode=self.resample_mode)
+        if self.flavor == 'enc':
+            if self.conv_skip is not None:
+                x = self.conv_skip(x)
+            x = normalize(x, dim=1) # pixel norm
+
+        # Residual branch.
+        y = self.conv_res0(mp_silu(x))
+        c = self.emb_linear(emb, gain=self.emb_gain) + 1
+        y = mp_silu(y * c.unsqueeze(2).unsqueeze(3).to(y.dtype))
+        if self.training and self.dropout != 0:
+            y = torch.nn.functional.dropout(y, p=self.dropout)
+        y = self.conv_res1(y)
+
+        # Connect the branches.
+        if self.flavor == 'dec' and self.conv_skip is not None:
+            x = self.conv_skip(x)
+        x = mp_sum(x, y, t=self.res_balance)
+
+        # Self-attention.
+        # Note: torch.nn.functional.scaled_dot_product_attention() could be used here,
+        # but we haven't done sufficient testing to verify that it produces identical results.
+        if self.num_heads != 0:
+            y = self.attn_qkv(x)
+            y = y.reshape(y.shape[0], self.num_heads, -1, 3, y.shape[2] * y.shape[3])
+            q, k, v = normalize(y, dim=2).unbind(3) # pixel norm & split
+            w = torch.einsum('nhcq,nhck->nhqk', q, k / np.sqrt(q.shape[2])).softmax(dim=3)
+            y = torch.einsum('nhqk,nhck->nhcq', w, v)
+            y = self.attn_proj(y.reshape(*x.shape))
+            x = mp_sum(x, y, t=self.attn_balance)
+
+        # Clip activations.
+        if self.clip_act is not None:
+            x = x.clip_(-self.clip_act, self.clip_act)
+        return x
+
+#----------------------------------------------------------------------------
+# EDM2 U-Net model (Figure 21).
+
+@persistence.persistent_class
+class UNet(torch.nn.Module):
+    def __init__(self,
+        img_resolution,                 # Image resolution.
+        img_channels,                   # Image channels.
+        label_dim,                      # Class label dimensionality. 0 = unconditional.
+        model_channels     = 192,       # Base multiplier for the number of channels.
+        channel_mult       = [1,2,3,4], # Per-resolution multipliers for the number of channels.
+        channel_mult_noise = None,      # Multiplier for noise embedding dimensionality. None = select based on channel_mult.
+        channel_mult_emb   = None,      # Multiplier for final embedding dimensionality. None = select based on channel_mult.
+        num_blocks         = 3,         # Number of residual blocks per resolution.
+        attn_resolutions   = [16,8],    # List of resolutions with self-attention.
+        label_balance      = 0.5,       # Balance between noise embedding (0) and class embedding (1).
+        concat_balance     = 0.5,       # Balance between skip connections (0) and main path (1).
+        **block_kwargs,                 # Arguments for Block.
+    ):
+        super().__init__()
+        cblock = [model_channels * x for x in channel_mult]
+        cnoise = model_channels * channel_mult_noise if channel_mult_noise is not None else cblock[0]
+        cemb = model_channels * channel_mult_emb if channel_mult_emb is not None else max(cblock)
+        self.label_balance = label_balance
+        self.concat_balance = concat_balance
+        self.out_gain = torch.nn.Parameter(torch.zeros([]))
+
+        # Embedding.
+        self.emb_fourier = MPFourier(cnoise)
+        self.emb_noise = MPConv(cnoise, cemb, kernel=[])
+        self.emb_label = MPConv(label_dim, cemb, kernel=[]) if label_dim != 0 else None
+
+        # Encoder.
+        self.enc = torch.nn.ModuleDict()
+        cout = img_channels + 1
+        for level, channels in enumerate(cblock):
+            res = img_resolution >> level
+            if level == 0:
+                cin = cout
+                cout = channels
+                self.enc[f'{res}x{res}_conv'] = MPConv(cin, cout, kernel=[3,3])
+            else:
+                self.enc[f'{res}x{res}_down'] = Block(cout, cout, cemb, flavor='enc', resample_mode='down', **block_kwargs)
+            for idx in range(num_blocks):
+                cin = cout
+                cout = channels
+                self.enc[f'{res}x{res}_block{idx}'] = Block(cin, cout, cemb, flavor='enc', attention=(res in attn_resolutions), **block_kwargs)
+
+        # Decoder.
+        self.dec = torch.nn.ModuleDict()
+        skips = [block.out_channels for block in self.enc.values()]
+        for level, channels in reversed(list(enumerate(cblock))):
+            res = img_resolution >> level
+            if level == len(cblock) - 1:
+                self.dec[f'{res}x{res}_in0'] = Block(cout, cout, cemb, flavor='dec', attention=True, **block_kwargs)
+                self.dec[f'{res}x{res}_in1'] = Block(cout, cout, cemb, flavor='dec', **block_kwargs)
+            else:
+                self.dec[f'{res}x{res}_up'] = Block(cout, cout, cemb, flavor='dec', resample_mode='up', **block_kwargs)
+            for idx in range(num_blocks + 1):
+                cin = cout + skips.pop()
+                cout = channels
+                self.dec[f'{res}x{res}_block{idx}'] = Block(cin, cout, cemb, flavor='dec', attention=(res in attn_resolutions), **block_kwargs)
+        self.out_conv = MPConv(cout, img_channels, kernel=[3,3])
+
+    def forward(self, x, noise_labels, class_labels):
+        # Embedding.
+        emb = self.emb_noise(self.emb_fourier(noise_labels))
+        if self.emb_label is not None:
+            emb = mp_sum(emb, self.emb_label(class_labels * np.sqrt(class_labels.shape[1])), t=self.label_balance)
+        emb = mp_silu(emb)
+
+        # Encoder.
+        x = torch.cat([x, torch.ones_like(x[:, :1])], dim=1)
+        skips = []
+        for name, block in self.enc.items():
+            x = block(x) if 'conv' in name else block(x, emb)
+            skips.append(x)
+
+        # Decoder.
+        for name, block in self.dec.items():
+            if 'block' in name:
+                x = mp_cat(x, skips.pop(), t=self.concat_balance)
+            x = block(x, emb)
+        x = self.out_conv(x, gain=self.out_gain)
+        return x
+
+#----------------------------------------------------------------------------
+# Preconditioning and uncertainty estimation.
+
+@persistence.persistent_class
+class Precond(torch.nn.Module):
+    def __init__(self,
+        img_resolution,                 # Image resolution.
+        img_channels,                   # Image channels.
+        label_dim,                      # Class label dimensionality. 0 = unconditional.
+        use_fp16        = True,         # Run the model at FP16 precision?
+        sigma_data      = 0.5,          # Expected standard deviation of the training data.
+        logvar_channels = 128,          # Intermediate dimensionality for uncertainty estimation.
+        **unet_kwargs,                  # Keyword arguments for UNet.
+    ):
+        super().__init__()
+        self.img_resolution = img_resolution
+        self.img_channels = img_channels
+        self.label_dim = label_dim
+        self.use_fp16 = use_fp16
+        self.sigma_data = sigma_data
+        self.unet = UNet(img_resolution=img_resolution, img_channels=img_channels, label_dim=label_dim, **unet_kwargs)
+        self.logvar_fourier = MPFourier(logvar_channels)
+        self.logvar_linear = MPConv(logvar_channels, 1, kernel=[])
+
+    def forward(self, x, sigma, class_labels=None, force_fp32=False, return_logvar=False, **unet_kwargs):
+        x = x.to(torch.float32)
+        sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
+        class_labels = None if self.label_dim == 0 else torch.zeros([1, self.label_dim], device=x.device) if class_labels is None else class_labels.to(torch.float32).reshape(-1, self.label_dim)
+        dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32
+
+        # Preconditioning weights.
+        c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
+        c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2).sqrt()
+        c_in = 1 / (self.sigma_data ** 2 + sigma ** 2).sqrt()
+        c_noise = sigma.flatten().log() / 4
+
+        # Run the model.
+        x_in = (c_in * x).to(dtype)
+        F_x = self.unet(x_in, c_noise, class_labels, **unet_kwargs)
+        D_x = c_skip * x + c_out * F_x.to(torch.float32)
+
+        # Estimate uncertainty if requested.
+        if return_logvar:
+            logvar = self.logvar_linear(self.logvar_fourier(c_noise)).reshape(-1, 1, 1, 1)
+            return D_x, logvar # u(sigma) in Equation 21
+        return D_x
+
+#----------------------------------------------------------------------------
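Finally, a small smoke test for the vendored Precond; this is a sketch, not part of the commit. The tiny channel counts are arbitrary choices to keep the forward pass cheap, and it assumes the Space also vendors torch_utils (persistence, misc), which networks_edm2.py imports:

# Hypothetical smoke test for the vendored EDM2 preconditioner.
import torch
from networks_edm2 import Precond

net = Precond(img_resolution=64, img_channels=3, label_dim=0,
              model_channels=32, num_blocks=1)  # small, assumed config
x = torch.randn(2, 3, 64, 64)                   # batch of noisy images
sigma = torch.full([2], 1.0)                    # one noise level per sample
D_x, logvar = net(x, sigma, force_fp32=True, return_logvar=True)
print(D_x.shape, logvar.shape)  # torch.Size([2, 3, 64, 64]) torch.Size([2, 1, 1, 1])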