ahsanMah committed on
Commit
f22f03c
·
1 Parent(s): d904b3a

+ HF models now built with config not pickle

Browse files
Files changed (3) hide show
  1. app.py +8 -4
  2. msma.py +22 -12
  3. networks_edm2.py +318 -0
app.py CHANGED
@@ -11,7 +11,12 @@ import torch
11
  from huggingface_hub import hf_hub_download
12
  from safetensors.torch import load_file
13
 
14
- from msma import ScoreFlow, build_model_from_pickle, config_presets
 
 
 
 
 
15
 
16
 
17
  @cache
@@ -32,7 +37,6 @@ def load_model_from_hub(preset, device):
32
  if 'DNNLIB_CACHE_DIR' in os.environ:
33
  cache_dir = os.environ["DNNLIB_CACHE_DIR"]
34
 
35
- scorenet = build_model_from_pickle(preset)
36
 
37
  for fname in ['config.json', 'gmm.pkl', 'refscores.npz', 'model.safetensors' ]:
38
  cached_fname = hf_hub_download(
@@ -49,10 +53,10 @@ def load_model_from_hub(preset, device):
49
  print("Loaded:", model_params)
50
 
51
  hf_checkpoint = f"{modeldir}/model.safetensors"
52
- model = ScoreFlow(scorenet, device=device, **model_params["PatchFlow"])
53
  model.load_state_dict(load_file(hf_checkpoint), strict=True)
54
  model = model.eval().requires_grad_(False)
55
-
56
  return model, modeldir
57
 
58
 
 
11
  from huggingface_hub import hf_hub_download
12
  from safetensors.torch import load_file
13
 
14
+ from msma import (
15
+ ScoreFlow,
16
+ build_model_from_config,
17
+ build_model_from_pickle,
18
+ config_presets,
19
+ )
20
 
21
 
22
  @cache
 
37
  if 'DNNLIB_CACHE_DIR' in os.environ:
38
  cache_dir = os.environ["DNNLIB_CACHE_DIR"]
39
 
 
40
 
41
  for fname in ['config.json', 'gmm.pkl', 'refscores.npz', 'model.safetensors' ]:
42
  cached_fname = hf_hub_download(
 
53
  print("Loaded:", model_params)
54
 
55
  hf_checkpoint = f"{modeldir}/model.safetensors"
56
+ model = build_model_from_config(model_params)
57
  model.load_state_dict(load_file(hf_checkpoint), strict=True)
58
  model = model.eval().requires_grad_(False)
59
+ model.to(device)
60
  return model, modeldir
61
 
62
 
msma.py CHANGED
@@ -21,6 +21,7 @@ from tqdm import tqdm
21
  import dnnlib
22
  from dataset import ImageFolderDataset
23
  from flowutils import PatchFlow, sanitize_locals
 
24
 
25
  DEVICE: Literal["cuda", "cpu"] = "cpu"
26
  model_root = "https://nvlabs-fi-cdn.nvidia.com/edm2/posthoc-reconstructions"
@@ -122,6 +123,14 @@ class ScoreFlow(torch.nn.Module):
122
  return self.flow(x_scores)
123
 
124
 
 
 
 
 
 
 
 
 
125
  def build_model_from_pickle(preset="edm2-img64-s-fid", device="cpu"):
126
  netpath = config_presets[preset]
127
  with dnnlib.util.open_url(netpath, verbose=1) as f:
@@ -196,13 +205,13 @@ def cmdline():
196
  def common_args(func):
197
  @wraps(func)
198
  @click.option(
199
- "--preset",
200
- help="Configuration preset",
201
- metavar="STR",
202
- type=str,
203
- default="edm2-img64-s-fid",
204
- show_default=True,
205
- )
206
  @click.option(
207
  "--dataset_path",
208
  help="Path to the dataset",
@@ -222,7 +231,8 @@ def common_args(func):
222
 
223
  return wrapper
224
 
225
- @cmdline.command('train-gmm')
 
226
  @click.option(
227
  "--gridsearch",
228
  help="Whether to use a grid search on a number of components to find the best fit",
@@ -365,7 +375,7 @@ def train_flow(dataset_path, preset, outdir, epochs, batch_size, **flow_kwargs):
365
  train_ds, batch_size=batch_size, num_workers=4, prefetch_factor=2, shuffle=True
366
  )
367
  testiter = torch.utils.data.DataLoader(
368
- val_ds, batch_size=batch_size*2, num_workers=4, prefetch_factor=2
369
  )
370
 
371
  scorenet = build_model_from_pickle(preset)
@@ -392,10 +402,10 @@ def train_flow(dataset_path, preset, outdir, epochs, batch_size, **flow_kwargs):
392
  timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M")
393
  writer = SummaryWriter(f"{experiment_dir}/logs/{timestamp}")
394
 
395
- with open(f"{experiment_dir}/logs/{timestamp}/config.json", "w") as f:
396
  json.dump(model.config, f, sort_keys=True, indent=4)
397
 
398
- with open(f"{experiment_dir}/config.json", "w") as f:
399
  json.dump(model.config, f, sort_keys=True, indent=4)
400
 
401
  # totaliters = int(epochs * train_len)
@@ -463,7 +473,7 @@ def train_flow(dataset_path, preset, outdir, epochs, batch_size, **flow_kwargs):
463
 
464
  # Save final model
465
  torch.save(model.flow.state_dict(), f"{experiment_dir}/flow.pt")
466
-
467
  writer.close()
468
 
469
 
 
21
  import dnnlib
22
  from dataset import ImageFolderDataset
23
  from flowutils import PatchFlow, sanitize_locals
24
+ from networks_edm2 import Precond
25
 
26
  DEVICE: Literal["cuda", "cpu"] = "cpu"
27
  model_root = "https://nvlabs-fi-cdn.nvidia.com/edm2/posthoc-reconstructions"
 
123
  return self.flow(x_scores)
124
 
125
 
126
def build_model_from_config(model_params):
    """Construct a ScoreFlow model from a nested configuration mapping.

    Args:
        model_params: Mapping with three entries, each a dict of keyword
            arguments: ``"EDMNet"`` (forwarded to ``Precond``),
            ``"EDMScorer"`` (forwarded to ``EDMScorer`` together with the
            network) and ``"PatchFlow"`` (forwarded to ``ScoreFlow`` together
            with the scorer).

    Returns:
        The assembled ``ScoreFlow`` instance. No weights are loaded here.

    Raises:
        KeyError: If one of the three sections is missing.
    """
    # Build from the inside out: denoiser -> scorer -> flow wrapper.
    denoiser = Precond(**model_params["EDMNet"])
    scorer = EDMScorer(net=denoiser, **model_params["EDMScorer"])
    model = ScoreFlow(scorenet=scorer, **model_params["PatchFlow"])
    print("Built model from config")
    return model
132
+
133
+
134
  def build_model_from_pickle(preset="edm2-img64-s-fid", device="cpu"):
135
  netpath = config_presets[preset]
136
  with dnnlib.util.open_url(netpath, verbose=1) as f:
 
205
  def common_args(func):
206
  @wraps(func)
207
  @click.option(
208
+ "--preset",
209
+ help="Configuration preset",
210
+ metavar="STR",
211
+ type=str,
212
+ default="edm2-img64-s-fid",
213
+ show_default=True,
214
+ )
215
  @click.option(
216
  "--dataset_path",
217
  help="Path to the dataset",
 
231
 
232
  return wrapper
233
 
234
+
235
+ @cmdline.command("train-gmm")
236
  @click.option(
237
  "--gridsearch",
238
  help="Whether to use a grid search on a number of components to find the best fit",
 
375
  train_ds, batch_size=batch_size, num_workers=4, prefetch_factor=2, shuffle=True
376
  )
377
  testiter = torch.utils.data.DataLoader(
378
+ val_ds, batch_size=batch_size * 2, num_workers=4, prefetch_factor=2
379
  )
380
 
381
  scorenet = build_model_from_pickle(preset)
 
402
  timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M")
403
  writer = SummaryWriter(f"{experiment_dir}/logs/{timestamp}")
404
 
405
+ with open(f"{experiment_dir}/logs/{timestamp}/config.json", "w") as f:
406
  json.dump(model.config, f, sort_keys=True, indent=4)
407
 
408
+ with open(f"{experiment_dir}/config.json", "w") as f:
409
  json.dump(model.config, f, sort_keys=True, indent=4)
410
 
411
  # totaliters = int(epochs * train_len)
 
473
 
474
  # Save final model
475
  torch.save(model.flow.state_dict(), f"{experiment_dir}/flow.pt")
476
+
477
  writer.close()
478
 
479
 
networks_edm2.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is licensed under a Creative Commons
4
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
5
+ # You should have received a copy of the license along with this
6
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7
+
8
+ """Improved diffusion model architecture proposed in the paper
9
+ "Analyzing and Improving the Training Dynamics of Diffusion Models"."""
10
+
11
+ import numpy as np
12
+ import torch
13
+ from torch_utils import persistence
14
+ from torch_utils import misc
15
+
16
+ #----------------------------------------------------------------------------
17
+ # Normalize given tensor to unit magnitude with respect to the given
18
+ # dimensions. Default = all dimensions except the first.
19
+
20
def normalize(x, dim=None, eps=1e-4):
    """Scale `x` to unit magnitude over the given dimensions.

    Args:
        x: Input tensor.
        dim: Dimension(s) to reduce over. None = every dimension except the first.
        eps: Constant added to the denominator to avoid division by zero.

    Returns:
        Tensor with the same shape and dtype as `x`.
    """
    reduce_dims = list(range(1, x.ndim)) if dim is None else dim
    # The norm is always accumulated in float32, regardless of the dtype of x.
    magnitude = torch.linalg.vector_norm(
        x, dim=reduce_dims, keepdim=True, dtype=torch.float32
    )
    # numel(norm) / numel(x) == 1 / (elements reduced per norm), so this factor
    # turns the L2 norm into a root-mean-square magnitude.
    rms_factor = np.sqrt(magnitude.numel() / x.numel())
    magnitude = torch.add(eps, magnitude, alpha=rms_factor)
    return x / magnitude.to(x.dtype)
26
+
27
+ #----------------------------------------------------------------------------
28
+ # Upsample or downsample the given tensor with the given filter,
29
+ # or keep it as is.
30
+
31
def resample(x, f=(1, 1), mode='keep'):
    """Upsample or downsample `x` by a factor of two, or return it unchanged.

    Args:
        x: Input tensor; channels are read from ``x.shape[1]`` and the tensor is
            fed to a 2-D convolution, so it is expected to be 4-D.
        f: 1-D resampling filter with an even number of taps. It is normalized
            to sum to one and expanded to a separable 2-D kernel.
        mode: ``'keep'`` (return `x` as is), ``'down'`` (strided depthwise
            convolution) or ``'up'`` (strided depthwise transposed convolution).

    Returns:
        The resampled tensor, or `x` itself when ``mode == 'keep'``.

    Raises:
        ValueError: If `mode` is not one of the three supported values, or if
            `f` is not 1-D with an even number of taps.
    """
    if mode == 'keep':
        return x
    # Validate explicitly: an `assert` disappears under `python -O`, which would
    # let an unknown mode fall through to the upsampling branch.
    if mode not in ('down', 'up'):
        raise ValueError(f"mode must be 'keep', 'up', or 'down', got {mode!r}")
    taps = np.asarray(f, dtype=np.float32)
    if taps.ndim != 1 or len(taps) % 2 != 0:
        raise ValueError('resampling filter must be 1-D with an even number of taps')

    pad = (len(taps) - 1) // 2
    taps = taps / taps.sum()
    kernel = np.outer(taps, taps)[np.newaxis, np.newaxis, :, :]
    # NOTE(review): misc.const_like presumably returns the kernel as a constant
    # tensor matching x's dtype/device -- confirm against torch_utils.misc.
    kernel = misc.const_like(x, kernel)
    channels = x.shape[1]
    if mode == 'down':
        return torch.nn.functional.conv2d(
            x, kernel.tile([channels, 1, 1, 1]), groups=channels, stride=2, padding=(pad,)
        )
    # mode == 'up': the kernel is scaled by 4 before the transposed convolution.
    return torch.nn.functional.conv_transpose2d(
        x, (kernel * 4).tile([channels, 1, 1, 1]), groups=channels, stride=2, padding=(pad,)
    )
45
+
46
+ #----------------------------------------------------------------------------
47
+ # Magnitude-preserving SiLU (Equation 81).
48
+
49
def mp_silu(x):
    """Apply SiLU to `x` and divide the result by the fixed constant 0.596."""
    activated = torch.nn.functional.silu(x)
    return activated / 0.596
51
+
52
+ #----------------------------------------------------------------------------
53
+ # Magnitude-preserving sum (Equation 88).
54
+
55
def mp_sum(a, b, t=0.5):
    """Blend `a` and `b` linearly and rescale by the norm of the blend weights.

    Args:
        a: First tensor (weight ``1 - t``).
        b: Second tensor (weight ``t``).
        t: Interpolation factor; 0 selects `a`, 1 selects `b`.

    Returns:
        ``lerp(a, b, t) / sqrt((1 - t)**2 + t**2)``.
    """
    blended = torch.lerp(a, b, t)
    weight_norm = np.sqrt((1 - t) ** 2 + t ** 2)
    return blended / weight_norm
57
+
58
+ #----------------------------------------------------------------------------
59
+ # Magnitude-preserving concatenation (Equation 103).
60
+
61
def mp_cat(a, b, dim=1, t=0.5):
    """Concatenate `a` and `b` along `dim` after weighting each input.

    Each input is scaled by a factor that depends on its own size along `dim`,
    the combined size, and the balance `t`, then the two are concatenated.

    Args:
        a: First tensor (weight proportional to ``1 - t``).
        b: Second tensor (weight proportional to ``t``).
        dim: Dimension to concatenate along.
        t: Balance between the two inputs; 0 keeps only `a`, 1 keeps only `b`.

    Returns:
        The concatenation of the two rescaled tensors.
    """
    size_a, size_b = a.shape[dim], b.shape[dim]
    overall = np.sqrt((size_a + size_b) / ((1 - t) ** 2 + t ** 2))
    scale_a = overall / np.sqrt(size_a) * (1 - t)
    scale_b = overall / np.sqrt(size_b) * t
    parts = [scale_a * a, scale_b * b]
    return torch.cat(parts, dim=dim)
68
+
69
+ #----------------------------------------------------------------------------
70
+ # Magnitude-preserving Fourier features (Equation 75).
71
+
72
@persistence.persistent_class
class MPFourier(torch.nn.Module):
    """Random Fourier feature embedding of a 1-D input.

    Frequencies and phases are drawn once at construction time and stored as
    buffers (not parameters), so they are saved with the module but not trained.
    """

    def __init__(self, num_channels, bandwidth=1):
        """
        Args:
            num_channels: Number of output features per input value.
            bandwidth: Multiplier applied to the normally distributed frequencies.
        """
        super().__init__()
        self.register_buffer('freqs', 2 * np.pi * torch.randn(num_channels) * bandwidth)
        self.register_buffer('phases', 2 * np.pi * torch.rand(num_channels))

    def forward(self, x):
        """Map a 1-D tensor of length N to an (N, num_channels) feature matrix.

        The computation runs in float32 and the result is cast back to the
        dtype of `x`.
        """
        y = x.to(torch.float32)
        # torch.outer replaces the deprecated alias Tensor.ger; same result.
        y = torch.outer(y, self.freqs.to(torch.float32))
        y = y + self.phases.to(torch.float32)
        y = y.cos() * np.sqrt(2)
        return y.to(x.dtype)
85
+
86
+ #----------------------------------------------------------------------------
87
+ # Magnitude-preserving convolution or fully-connected layer (Equation 47)
88
+ # with force weight normalization (Equation 66).
89
+
90
@persistence.persistent_class
class MPConv(torch.nn.Module):
    """Convolution or fully-connected layer with normalized weights and no bias.

    The kernel list decides the layer type: ``[]`` gives a 2-D weight (linear
    layer), ``[k, k]`` gives a 4-D weight (2-D convolution).
    """

    def __init__(self, in_channels, out_channels, kernel):
        """
        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            kernel: Kernel size as a list; empty for a fully-connected layer.
        """
        super().__init__()
        self.out_channels = out_channels
        self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, *kernel))

    def forward(self, x, gain=1):
        """Apply the layer to `x`, scaling the normalized weight by `gain`.

        Args:
            x: Input tensor; 2-D weights use a matrix product, 4-D weights a
                same-padded 2-D convolution.
            gain: Scalar (number or tensor) multiplied into the weight.
        """
        weight = self.weight.to(torch.float32)
        if self.training:
            # Forced weight normalization: overwrite the stored parameter with
            # its normalized value, outside of autograd.
            with torch.no_grad():
                self.weight.copy_(normalize(weight))
        # Traditional weight normalization (this one is tracked by autograd).
        weight = normalize(weight)
        # Magnitude-preserving scaling by 1 / sqrt(fan-in).
        fan_in = weight[0].numel()
        weight = weight * (gain / np.sqrt(fan_in))
        weight = weight.to(x.dtype)
        if weight.ndim == 2:
            return x @ weight.t()
        assert weight.ndim == 4
        half_kernel = weight.shape[-1] // 2
        return torch.nn.functional.conv2d(x, weight, padding=(half_kernel,))
109
+
110
+ #----------------------------------------------------------------------------
111
+ # U-Net encoder/decoder block with optional self-attention (Figure 21).
112
+
113
@persistence.persistent_class
class Block(torch.nn.Module):
    """U-Net encoder/decoder block with optional self-attention.

    The block resamples its input, runs a two-convolution residual branch that
    is modulated by the embedding, blends it with the main branch, optionally
    applies self-attention, and finally clips the activations.
    """

    def __init__(self,
        in_channels,                    # Number of input channels.
        out_channels,                   # Number of output channels.
        emb_channels,                   # Number of embedding channels.
        flavor              = 'enc',    # Flavor: 'enc' or 'dec'.
        resample_mode       = 'keep',   # Resampling: 'keep', 'up', or 'down'.
        resample_filter     = [1,1],    # Resampling filter.
        attention           = False,    # Include self-attention?
        channels_per_head   = 64,       # Number of channels per attention head.
        dropout             = 0,        # Dropout probability.
        res_balance         = 0.3,      # Balance between main branch (0) and residual branch (1).
        attn_balance        = 0.3,      # Balance between main branch (0) and self-attention (1).
        clip_act            = 256,      # Clip output activations. None = do not clip.
    ):
        super().__init__()

        # Plain configuration attributes.
        self.out_channels = out_channels
        self.flavor = flavor
        self.resample_filter = resample_filter
        self.resample_mode = resample_mode
        self.num_heads = out_channels // channels_per_head if attention else 0
        self.dropout = dropout
        self.res_balance = res_balance
        self.attn_balance = attn_balance
        self.clip_act = clip_act

        # Learnable pieces. Creation order is kept fixed because it determines
        # both parameter ordering and the sequence of random initializations.
        has_attention = self.num_heads != 0
        # Encoder blocks apply the skip conv before the residual branch, so the
        # first residual conv already sees out_channels; decoder blocks do not.
        res0_inputs = out_channels if flavor == 'enc' else in_channels
        self.emb_gain = torch.nn.Parameter(torch.zeros([]))
        self.conv_res0 = MPConv(res0_inputs, out_channels, kernel=[3,3])
        self.emb_linear = MPConv(emb_channels, out_channels, kernel=[])
        self.conv_res1 = MPConv(out_channels, out_channels, kernel=[3,3])
        self.conv_skip = MPConv(in_channels, out_channels, kernel=[1,1]) if in_channels != out_channels else None
        self.attn_qkv = MPConv(out_channels, out_channels * 3, kernel=[1,1]) if has_attention else None
        self.attn_proj = MPConv(out_channels, out_channels, kernel=[1,1]) if has_attention else None

    def forward(self, x, emb):
        """Run the block on feature map `x` with embedding `emb`.

        Args:
            x: Feature map; treated as 4-D (it is convolved and indexed as
                batch, channels, height, width).
            emb: Embedding fed to the linear modulation layer; its output is
                broadcast over the two trailing dimensions of the features.

        Returns:
            The block output with `out_channels` channels.
        """
        is_encoder = self.flavor == 'enc'

        # Main branch.
        x = resample(x, f=self.resample_filter, mode=self.resample_mode)
        if is_encoder:
            if self.conv_skip is not None:
                x = self.conv_skip(x)
            x = normalize(x, dim=1) # pixel norm

        # Residual branch.
        residual = self.conv_res0(mp_silu(x))
        # Modulation is (1 + linear(emb)); emb_gain starts at zero.
        modulation = self.emb_linear(emb, gain=self.emb_gain) + 1
        residual = mp_silu(residual * modulation.unsqueeze(2).unsqueeze(3).to(residual.dtype))
        if self.training and self.dropout != 0:
            residual = torch.nn.functional.dropout(residual, p=self.dropout)
        residual = self.conv_res1(residual)

        # Connect the branches.
        if not is_encoder and self.conv_skip is not None:
            x = self.conv_skip(x)
        x = mp_sum(x, residual, t=self.res_balance)

        # Self-attention.
        # Note: torch.nn.functional.scaled_dot_product_attention() could be used here,
        # but we haven't done sufficient testing to verify that it produces identical results.
        if self.num_heads != 0:
            qkv = self.attn_qkv(x)
            batch, pixels = qkv.shape[0], qkv.shape[2] * qkv.shape[3]
            # Layout: (batch, heads, channels per head, {q, k, v}, pixels).
            qkv = qkv.reshape(batch, self.num_heads, -1, 3, pixels)
            q, k, v = normalize(qkv, dim=2).unbind(3) # pixel norm & split
            scores = torch.einsum('nhcq,nhck->nhqk', q, k / np.sqrt(q.shape[2]))
            weights = scores.softmax(dim=3)
            attended = torch.einsum('nhqk,nhck->nhcq', weights, v)
            attended = self.attn_proj(attended.reshape(*x.shape))
            x = mp_sum(x, attended, t=self.attn_balance)

        # Clip activations (in place).
        if self.clip_act is not None:
            x = x.clip_(-self.clip_act, self.clip_act)
        return x
184
+
185
+ #----------------------------------------------------------------------------
186
+ # EDM2 U-Net model (Figure 21).
187
+
188
@persistence.persistent_class
class UNet(torch.nn.Module):
    """EDM2 U-Net: an encoder/decoder stack of `Block`s with skip connections.

    Every encoder output is stored as a skip; each decoder layer whose name
    contains 'block' consumes one skip via `mp_cat`.
    """

    def __init__(self,
        img_resolution,                     # Image resolution.
        img_channels,                       # Image channels.
        label_dim,                          # Class label dimensionality. 0 = unconditional.
        model_channels      = 192,          # Base multiplier for the number of channels.
        channel_mult        = [1,2,3,4],    # Per-resolution multipliers for the number of channels.
        channel_mult_noise  = None,         # Multiplier for noise embedding dimensionality. None = select based on channel_mult.
        channel_mult_emb    = None,         # Multiplier for final embedding dimensionality. None = select based on channel_mult.
        num_blocks          = 3,            # Number of residual blocks per resolution.
        attn_resolutions    = [16,8],       # List of resolutions with self-attention.
        label_balance       = 0.5,          # Balance between noise embedding (0) and class embedding (1).
        concat_balance      = 0.5,          # Balance between skip connections (0) and main path (1).
        **block_kwargs,                     # Arguments for Block.
    ):
        super().__init__()
        level_channels = [model_channels * mult for mult in channel_mult]
        if channel_mult_noise is not None:
            noise_channels = model_channels * channel_mult_noise
        else:
            noise_channels = level_channels[0]
        if channel_mult_emb is not None:
            emb_channels = model_channels * channel_mult_emb
        else:
            emb_channels = max(level_channels)
        self.label_balance = label_balance
        self.concat_balance = concat_balance
        self.out_gain = torch.nn.Parameter(torch.zeros([]))

        # Embedding.
        self.emb_fourier = MPFourier(noise_channels)
        self.emb_noise = MPConv(noise_channels, emb_channels, kernel=[])
        self.emb_label = MPConv(label_dim, emb_channels, kernel=[]) if label_dim != 0 else None

        # Encoder. The input gets one extra constant channel (see forward).
        self.enc = torch.nn.ModuleDict()
        cout = img_channels + 1
        for level, channels in enumerate(level_channels):
            res = img_resolution >> level
            tag = f'{res}x{res}'
            if level == 0:
                cin, cout = cout, channels
                self.enc[f'{tag}_conv'] = MPConv(cin, cout, kernel=[3,3])
            else:
                self.enc[f'{tag}_down'] = Block(cout, cout, emb_channels, flavor='enc', resample_mode='down', **block_kwargs)
            use_attention = res in attn_resolutions
            for idx in range(num_blocks):
                cin, cout = cout, channels
                self.enc[f'{tag}_block{idx}'] = Block(cin, cout, emb_channels, flavor='enc', attention=use_attention, **block_kwargs)

        # Decoder. Skip widths are popped in reverse order of the encoder layers.
        self.dec = torch.nn.ModuleDict()
        skips = [layer.out_channels for layer in self.enc.values()]
        last_level = len(level_channels) - 1
        for level, channels in reversed(list(enumerate(level_channels))):
            res = img_resolution >> level
            tag = f'{res}x{res}'
            if level == last_level:
                self.dec[f'{tag}_in0'] = Block(cout, cout, emb_channels, flavor='dec', attention=True, **block_kwargs)
                self.dec[f'{tag}_in1'] = Block(cout, cout, emb_channels, flavor='dec', **block_kwargs)
            else:
                self.dec[f'{tag}_up'] = Block(cout, cout, emb_channels, flavor='dec', resample_mode='up', **block_kwargs)
            use_attention = res in attn_resolutions
            for idx in range(num_blocks + 1):
                cin = cout + skips.pop()
                cout = channels
                self.dec[f'{tag}_block{idx}'] = Block(cin, cout, emb_channels, flavor='dec', attention=use_attention, **block_kwargs)
        self.out_conv = MPConv(cout, img_channels, kernel=[3,3])

    def forward(self, x, noise_labels, class_labels):
        """Run the U-Net.

        Args:
            x: Input image batch.
            noise_labels: 1-D tensor of per-sample noise conditioning values.
            class_labels: Class conditioning; only read when the model was built
                with a non-zero `label_dim`.

        Returns:
            Tensor with `img_channels` channels produced by the output convolution.
        """
        # Embedding.
        emb = self.emb_noise(self.emb_fourier(noise_labels))
        if self.emb_label is not None:
            scaled_labels = class_labels * np.sqrt(class_labels.shape[1])
            emb = mp_sum(emb, self.emb_label(scaled_labels), t=self.label_balance)
        emb = mp_silu(emb)

        # Encoder. A channel of ones is appended to the input.
        ones = torch.ones_like(x[:, :1])
        x = torch.cat([x, ones], dim=1)
        skips = []
        for name, layer in self.enc.items():
            # The plain input convolution takes no embedding.
            if 'conv' in name:
                x = layer(x)
            else:
                x = layer(x, emb)
            skips.append(x)

        # Decoder.
        for name, layer in self.dec.items():
            if 'block' in name:
                x = mp_cat(x, skips.pop(), t=self.concat_balance)
            x = layer(x, emb)
        return self.out_conv(x, gain=self.out_gain)
270
+
271
+ #----------------------------------------------------------------------------
272
+ # Preconditioning and uncertainty estimation.
273
+
274
@persistence.persistent_class
class Precond(torch.nn.Module):
    """Preconditioning wrapper around `UNet` with optional uncertainty output.

    The wrapper scales the network input and output as functions of the noise
    level `sigma`, and can additionally return a per-sample log-variance.
    """

    def __init__(self,
        img_resolution,         # Image resolution.
        img_channels,           # Image channels.
        label_dim,              # Class label dimensionality. 0 = unconditional.
        use_fp16        = True, # Run the model at FP16 precision?
        sigma_data      = 0.5,  # Expected standard deviation of the training data.
        logvar_channels = 128,  # Intermediate dimensionality for uncertainty estimation.
        **unet_kwargs,          # Keyword arguments for UNet.
    ):
        super().__init__()
        self.img_resolution = img_resolution
        self.img_channels = img_channels
        self.label_dim = label_dim
        self.use_fp16 = use_fp16
        self.sigma_data = sigma_data
        self.unet = UNet(img_resolution=img_resolution, img_channels=img_channels, label_dim=label_dim, **unet_kwargs)
        self.logvar_fourier = MPFourier(logvar_channels)
        self.logvar_linear = MPConv(logvar_channels, 1, kernel=[])

    def forward(self, x, sigma, class_labels=None, force_fp32=False, return_logvar=False, **unet_kwargs):
        """Compute the preconditioned output for input `x` at noise level `sigma`.

        Args:
            x: Input batch; converted to float32.
            sigma: Noise level(s); reshaped to (-1, 1, 1, 1) for broadcasting.
            class_labels: Optional class conditioning. Ignored when
                `label_dim` is 0; replaced by a single all-zero row when None.
            force_fp32: Run the U-Net in float32 even if `use_fp16` is set.
            return_logvar: Also return the log-variance estimate.
            **unet_kwargs: Extra keyword arguments forwarded to the U-Net call.

        Returns:
            `D_x`, or the tuple ``(D_x, logvar)`` when `return_logvar` is true.
        """
        x = x.to(torch.float32)
        sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)

        # Resolve the class conditioning.
        if self.label_dim == 0:
            class_labels = None
        elif class_labels is None:
            class_labels = torch.zeros([1, self.label_dim], device=x.device)
        else:
            class_labels = class_labels.to(torch.float32).reshape(-1, self.label_dim)

        # Half precision is used only on CUDA, when enabled and not overridden.
        half_ok = self.use_fp16 and not force_fp32 and x.device.type == 'cuda'
        dtype = torch.float16 if half_ok else torch.float32

        # Preconditioning weights.
        total_var = sigma ** 2 + self.sigma_data ** 2
        c_skip = self.sigma_data ** 2 / total_var
        c_out = sigma * self.sigma_data / total_var.sqrt()
        c_in = 1 / total_var.sqrt()
        c_noise = sigma.flatten().log() / 4

        # Run the model; the skip/output combination is done in float32.
        net_input = (c_in * x).to(dtype)
        F_x = self.unet(net_input, c_noise, class_labels, **unet_kwargs)
        D_x = c_skip * x + c_out * F_x.to(torch.float32)

        if not return_logvar:
            return D_x

        # Estimate uncertainty if requested.
        logvar = self.logvar_linear(self.logvar_fourier(c_noise)).reshape(-1, 1, 1, 1)
        return D_x, logvar # u(sigma) in Equation 21
317
+
318
+ #----------------------------------------------------------------------------