ahsanMah commited on
Commit
6e569fe
·
1 Parent(s): 95a02fd

saving configs and making models easier to load

Browse files
Files changed (2) hide show
  1. flowutils.py +63 -44
  2. msma.py +32 -15
flowutils.py CHANGED
@@ -8,6 +8,66 @@ from einops import rearrange, repeat
8
  from normflows.distributions import BaseDistribution
9
 
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  class ConditionalDiagGaussian(BaseDistribution):
12
  """
13
  Conditional multivariate Gaussian distribution with diagonal
@@ -61,50 +121,6 @@ class ConditionalDiagGaussian(BaseDistribution):
61
  return log_p
62
 
63
 
64
- def build_flows(
65
- latent_size, num_flows=4, num_blocks_per_flow=2, hidden_units=128, context_size=64
66
- ):
67
- # Define flows
68
-
69
- flows = []
70
-
71
- flows.append(
72
- nf.flows.MaskedAffineAutoregressive(
73
- latent_size,
74
- hidden_features=hidden_units,
75
- num_blocks=num_blocks_per_flow,
76
- context_features=context_size,
77
- )
78
- )
79
-
80
- for i in range(num_flows):
81
- flows += [
82
- nf.flows.CoupledRationalQuadraticSpline(
83
- latent_size,
84
- num_blocks=num_blocks_per_flow,
85
- num_hidden_channels=hidden_units,
86
- num_context_channels=context_size,
87
- )
88
- ]
89
- flows += [nf.flows.LULinearPermute(latent_size)]
90
-
91
- # Set base distribution
92
-
93
- context_encoder = nn.Sequential(
94
- nn.Linear(context_size, context_size),
95
- nn.SiLU(),
96
- # output mean and scales for K=latent_size dimensions
97
- nn.Linear(context_size, latent_size * 2),
98
- )
99
-
100
- q0 = ConditionalDiagGaussian(latent_size, context_encoder)
101
-
102
- # Construct flow model
103
- model = nf.ConditionalNormalizingFlow(q0, flows)
104
-
105
- return model
106
-
107
-
108
  def get_emb(sin_inp):
109
  """
110
  Gets a base embedding for one dimension with sin and cos intertwined
@@ -204,6 +220,9 @@ class PatchFlow(torch.nn.Module):
204
  hidden_units=128,
205
  ):
206
  super().__init__()
 
 
 
207
  num_sigmas, c, h, w = input_size
208
  self.local_pooler = SpatialNormer(
209
  in_channels=num_sigmas, kernel_size=patch_size
 
8
  from normflows.distributions import BaseDistribution
9
 
10
 
11
+ def sanitize_locals(args_dict, ignore_keys=None):
12
+
13
+ if ignore_keys is None:
14
+ ignore_keys = []
15
+
16
+ if not isinstance(ignore_keys, list):
17
+ ignore_keys = [ignore_keys]
18
+
19
+ _dict = args_dict.copy()
20
+ _dict.pop("self")
21
+ class_name = _dict.pop("__class__").__name__
22
+ class_params = {k: v for k, v in _dict.items() if k not in ignore_keys}
23
+
24
+ return {class_name: class_params}
25
+
26
+
27
+ def build_flows(
28
+ latent_size, num_flows=4, num_blocks_per_flow=2, hidden_units=128, context_size=64
29
+ ):
30
+ # Define flows
31
+
32
+ flows = []
33
+
34
+ flows.append(
35
+ nf.flows.MaskedAffineAutoregressive(
36
+ latent_size,
37
+ hidden_features=hidden_units,
38
+ num_blocks=num_blocks_per_flow,
39
+ context_features=context_size,
40
+ )
41
+ )
42
+
43
+ for i in range(num_flows):
44
+ flows += [
45
+ nf.flows.CoupledRationalQuadraticSpline(
46
+ latent_size,
47
+ num_blocks=num_blocks_per_flow,
48
+ num_hidden_channels=hidden_units,
49
+ num_context_channels=context_size,
50
+ )
51
+ ]
52
+ flows += [nf.flows.LULinearPermute(latent_size)]
53
+
54
+ # Set base distribution
55
+
56
+ context_encoder = nn.Sequential(
57
+ nn.Linear(context_size, context_size),
58
+ nn.SiLU(),
59
+ # output mean and scales for K=latent_size dimensions
60
+ nn.Linear(context_size, latent_size * 2),
61
+ )
62
+
63
+ q0 = ConditionalDiagGaussian(latent_size, context_encoder)
64
+
65
+ # Construct flow model
66
+ model = nf.ConditionalNormalizingFlow(q0, flows)
67
+
68
+ return model
69
+
70
+
71
  class ConditionalDiagGaussian(BaseDistribution):
72
  """
73
  Conditional multivariate Gaussian distribution with diagonal
 
121
  return log_p
122
 
123
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  def get_emb(sin_inp):
125
  """
126
  Gets a base embedding for one dimension with sin and cos intertwined
 
220
  hidden_units=128,
221
  ):
222
  super().__init__()
223
+
224
+ self.config = sanitize_locals(locals(), ignore_keys=input_size)
225
+
226
  num_sigmas, c, h, w = input_size
227
  self.local_pooler = SpatialNormer(
228
  in_channels=num_sigmas, kernel_size=patch_size
msma.py CHANGED
@@ -19,7 +19,7 @@ from tqdm import tqdm
19
 
20
  import dnnlib
21
  from dataset import ImageFolderDataset
22
- from flowutils import PatchFlow
23
 
24
  DEVICE: Literal["cuda", "cpu"] = 'cpu'
25
  model_root = "https://nvlabs-fi-cdn.nvidia.com/edm2/posthoc-reconstructions"
@@ -53,9 +53,12 @@ class EDMScorer(torch.nn.Module):
53
  sigma_max=80, # Maximum supported noise level.
54
  sigma_data=0.5, # Expected standard deviation of the training data.
55
  rho=7, # Time step discretization.
56
- device=torch.device("cpu"), # Device to use.
57
  ):
58
  super().__init__()
 
 
 
 
59
  self.use_fp16 = use_fp16
60
  self.sigma_min = sigma_min
61
  self.sigma_max = sigma_max
@@ -67,14 +70,13 @@ class EDMScorer(torch.nn.Module):
67
  self.sigma_min = 1e-1
68
  self.sigma_max = sigma_max * stop_ratio
69
 
70
- step_indices = torch.arange(num_steps, dtype=torch.float64, device=device)
71
  t_steps = (
72
  self.sigma_max ** (1 / rho)
73
  + step_indices
74
  / (num_steps - 1)
75
  * (self.sigma_min ** (1 / rho) - self.sigma_max ** (1 / rho))
76
  ) ** rho
77
- # print("Using steps:", t_steps)
78
 
79
  self.register_buffer("sigma_steps", t_steps.to(torch.float64))
80
 
@@ -100,28 +102,32 @@ class EDMScorer(torch.nn.Module):
100
  class ScoreFlow(torch.nn.Module):
101
  def __init__(
102
  self,
103
- preset,
104
  device="cpu",
105
  **flow_kwargs
106
  ):
107
  super().__init__()
108
 
109
- scorenet = build_model(preset)
110
  h = w = scorenet.net.img_resolution
111
  c = scorenet.net.img_channels
112
  num_sigmas = len(scorenet.sigma_steps)
113
  self.flow = PatchFlow((num_sigmas, c, h, w), **flow_kwargs)
 
114
 
115
  self.flow = self.flow.to(device)
116
  self.scorenet = scorenet.to(device).requires_grad_(False)
117
  self.flow.init_weights()
118
 
 
 
 
 
119
  def forward(self, x, **score_kwargs):
120
  x_scores = self.scorenet(x, **score_kwargs)
121
  return self.flow(x_scores)
122
 
123
 
124
- def build_model(preset="edm2-img64-s-fid", device="cpu"):
125
  netpath = config_presets[preset]
126
  with dnnlib.util.open_url(netpath, verbose=1) as f:
127
  data = pickle.load(f)
@@ -198,7 +204,7 @@ def test_runner(device="cpu"):
198
  image = np.array(image)
199
  image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1)
200
  x = torch.from_numpy(image).unsqueeze(0).to(device)
201
- model = build_model(device=device)
202
  scores = model(x)
203
 
204
  return scores
@@ -211,8 +217,8 @@ def test_flow_runner(preset, device="cpu", load_weights=None):
211
  image = np.array(image)
212
  image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1)
213
  x = torch.from_numpy(image).unsqueeze(0).to(device)
214
-
215
- score_flow = ScoreFlow(preset, device=device)
216
 
217
  if load_weights is not None:
218
  score_flow.flow.load_state_dict(torch.load(load_weights))
@@ -272,7 +278,7 @@ def cache_score_norms(preset, dataset_path, outdir):
272
  dsobj, batch_size=64, num_workers=4, prefetch_factor=2
273
  )
274
 
275
- model = build_model(preset=preset, device=device)
276
  score_norms = []
277
 
278
  for x, _ in tqdm(dsloader):
@@ -312,6 +318,14 @@ def cache_score_norms(preset, dataset_path, outdir):
312
  default="edm2-img64-s-fid",
313
  show_default=True,
314
  )
 
 
 
 
 
 
 
 
315
  @click.option(
316
  "--num_flows",
317
  help="Number of normalizing flow functions in the PatchFlow model",
@@ -320,7 +334,7 @@ def cache_score_norms(preset, dataset_path, outdir):
320
  default=4,
321
  show_default=True,
322
  )
323
- def train_flow(dataset_path, preset, outdir, epochs=10, **flow_kwargs):
324
  print("using device:", DEVICE)
325
  device = DEVICE
326
  dsobj = ImageFolderDataset(path=dataset_path, resolution=64)
@@ -345,7 +359,8 @@ def train_flow(dataset_path, preset, outdir, epochs=10, **flow_kwargs):
345
  val_ds, batch_size=128, num_workers=4, prefetch_factor=2
346
  )
347
 
348
- model = ScoreFlow(preset, device=device, **flow_kwargs)
 
349
  opt = torch.optim.AdamW(model.flow.parameters(), lr=3e-4, weight_decay=1e-5)
350
  train_step = partial(
351
  PatchFlow.stochastic_step,
@@ -373,6 +388,7 @@ def train_flow(dataset_path, preset, outdir, epochs=10, **flow_kwargs):
373
  step = 0
374
 
375
  for e in pbar:
 
376
  for x, _ in trainiter:
377
  x = x.to(device)
378
  scores = model.scorenet(x)
@@ -411,13 +427,14 @@ def train_flow(dataset_path, preset, outdir, epochs=10, **flow_kwargs):
411
  # Squeeze the juice
412
  best_ckpt = torch.load(f"{experiment_dir}/flow.pt")
413
  model.flow.load_state_dict(best_ckpt)
414
- for i, (x, _) in enumerate(testiter):
 
415
  x = x.to(device)
416
  scores = model.scorenet(x)
417
  train_loss = train_step(scores, x)
418
  writer.add_scalar("loss/train", train_loss, step)
419
  pbar.set_description(
420
- f"(Tuning) Step: {step:d} - Train: {train_loss:.3f} - Val: {val_loss:.3f}"
421
  )
422
  step += 1
423
 
 
19
 
20
  import dnnlib
21
  from dataset import ImageFolderDataset
22
+ from flowutils import PatchFlow, sanitize_locals
23
 
24
  DEVICE: Literal["cuda", "cpu"] = 'cpu'
25
  model_root = "https://nvlabs-fi-cdn.nvidia.com/edm2/posthoc-reconstructions"
 
53
  sigma_max=80, # Maximum supported noise level.
54
  sigma_data=0.5, # Expected standard deviation of the training data.
55
  rho=7, # Time step discretization.
 
56
  ):
57
  super().__init__()
58
+
59
+ self.config = sanitize_locals(locals(), ignore_keys='net')
60
+ self.config['EDMNet'] = dict(net.init_kwargs)
61
+
62
  self.use_fp16 = use_fp16
63
  self.sigma_min = sigma_min
64
  self.sigma_max = sigma_max
 
70
  self.sigma_min = 1e-1
71
  self.sigma_max = sigma_max * stop_ratio
72
 
73
+ step_indices = torch.arange(num_steps, dtype=torch.float64)
74
  t_steps = (
75
  self.sigma_max ** (1 / rho)
76
  + step_indices
77
  / (num_steps - 1)
78
  * (self.sigma_min ** (1 / rho) - self.sigma_max ** (1 / rho))
79
  ) ** rho
 
80
 
81
  self.register_buffer("sigma_steps", t_steps.to(torch.float64))
82
 
 
102
  class ScoreFlow(torch.nn.Module):
103
  def __init__(
104
  self,
105
+ scorenet,
106
  device="cpu",
107
  **flow_kwargs
108
  ):
109
  super().__init__()
110
 
 
111
  h = w = scorenet.net.img_resolution
112
  c = scorenet.net.img_channels
113
  num_sigmas = len(scorenet.sigma_steps)
114
  self.flow = PatchFlow((num_sigmas, c, h, w), **flow_kwargs)
115
+
116
 
117
  self.flow = self.flow.to(device)
118
  self.scorenet = scorenet.to(device).requires_grad_(False)
119
  self.flow.init_weights()
120
 
121
+ self.config = dict()
122
+ self.config.update(**self.scorenet.config)
123
+ self.config.update(self.flow.config)
124
+
125
  def forward(self, x, **score_kwargs):
126
  x_scores = self.scorenet(x, **score_kwargs)
127
  return self.flow(x_scores)
128
 
129
 
130
+ def build_model_from_pickle(preset="edm2-img64-s-fid", device="cpu"):
131
  netpath = config_presets[preset]
132
  with dnnlib.util.open_url(netpath, verbose=1) as f:
133
  data = pickle.load(f)
 
204
  image = np.array(image)
205
  image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1)
206
  x = torch.from_numpy(image).unsqueeze(0).to(device)
207
+ model = build_model_from_pickle(device=device)
208
  scores = model(x)
209
 
210
  return scores
 
217
  image = np.array(image)
218
  image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1)
219
  x = torch.from_numpy(image).unsqueeze(0).to(device)
220
+ scorenet = build_model_from_pickle(preset)
221
+ score_flow = ScoreFlow(scorenet, device=device)
222
 
223
  if load_weights is not None:
224
  score_flow.flow.load_state_dict(torch.load(load_weights))
 
278
  dsobj, batch_size=64, num_workers=4, prefetch_factor=2
279
  )
280
 
281
+ model = build_model_from_pickle(preset=preset, device=device)
282
  score_norms = []
283
 
284
  for x, _ in tqdm(dsloader):
 
318
  default="edm2-img64-s-fid",
319
  show_default=True,
320
  )
321
+ @click.option(
322
+ "--epochs",
323
+ help="Number of epochs",
324
+ metavar="INT",
325
+ type=int,
326
+ default=10,
327
+ show_default=True,
328
+ )
329
  @click.option(
330
  "--num_flows",
331
  help="Number of normalizing flow functions in the PatchFlow model",
 
334
  default=4,
335
  show_default=True,
336
  )
337
+ def train_flow(dataset_path, preset, outdir, epochs, **flow_kwargs):
338
  print("using device:", DEVICE)
339
  device = DEVICE
340
  dsobj = ImageFolderDataset(path=dataset_path, resolution=64)
 
359
  val_ds, batch_size=128, num_workers=4, prefetch_factor=2
360
  )
361
 
362
+ scorenet = build_model_from_pickle(preset)
363
+ model = ScoreFlow(scorenet, device=device, **flow_kwargs)
364
  opt = torch.optim.AdamW(model.flow.parameters(), lr=3e-4, weight_decay=1e-5)
365
  train_step = partial(
366
  PatchFlow.stochastic_step,
 
388
  step = 0
389
 
390
  for e in pbar:
391
+
392
  for x, _ in trainiter:
393
  x = x.to(device)
394
  scores = model.scorenet(x)
 
427
  # Squeeze the juice
428
  best_ckpt = torch.load(f"{experiment_dir}/flow.pt")
429
  model.flow.load_state_dict(best_ckpt)
430
+ pbar = tqdm(testiter, desc="(Tuning) Step:? - Loss: ?")
431
+ for x, _ in pbar:
432
  x = x.to(device)
433
  scores = model.scorenet(x)
434
  train_loss = train_step(scores, x)
435
  writer.add_scalar("loss/train", train_loss, step)
436
  pbar.set_description(
437
+ f"(Tuning) Step: {step:d} - Loss: {train_loss:.3f}"
438
  )
439
  step += 1
440