ahsanMah commited on
Commit
8476a9b
·
1 Parent(s): 8933ee4

saving model configs

Browse files
Files changed (2) hide show
  1. flowutils.py +1 -1
  2. msma.py +25 -17
flowutils.py CHANGED
@@ -221,7 +221,7 @@ class PatchFlow(torch.nn.Module):
221
  ):
222
  super().__init__()
223
 
224
- self.config = sanitize_locals(locals(), ignore_keys=input_size)
225
 
226
  num_sigmas, c, h, w = input_size
227
  self.local_pooler = SpatialNormer(
 
221
  ):
222
  super().__init__()
223
 
224
+ self.config = sanitize_locals(locals(), ignore_keys="input_size")
225
 
226
  num_sigmas, c, h, w = input_size
227
  self.local_pooler = SpatialNormer(
msma.py CHANGED
@@ -1,4 +1,5 @@
1
  import datetime
 
2
  import os
3
  import pickle
4
  from functools import partial
@@ -21,7 +22,7 @@ import dnnlib
21
  from dataset import ImageFolderDataset
22
  from flowutils import PatchFlow, sanitize_locals
23
 
24
- DEVICE: Literal["cuda", "cpu"] = 'cpu'
25
  model_root = "https://nvlabs-fi-cdn.nvidia.com/edm2/posthoc-reconstructions"
26
 
27
  config_presets = {
@@ -56,8 +57,8 @@ class EDMScorer(torch.nn.Module):
56
  ):
57
  super().__init__()
58
 
59
- self.config = sanitize_locals(locals(), ignore_keys='net')
60
- self.config['EDMNet'] = dict(net.init_kwargs)
61
 
62
  self.use_fp16 = use_fp16
63
  self.sigma_min = sigma_min
@@ -100,19 +101,13 @@ class EDMScorer(torch.nn.Module):
100
 
101
 
102
  class ScoreFlow(torch.nn.Module):
103
- def __init__(
104
- self,
105
- scorenet,
106
- device="cpu",
107
- **flow_kwargs
108
- ):
109
  super().__init__()
110
 
111
  h = w = scorenet.net.img_resolution
112
  c = scorenet.net.img_channels
113
  num_sigmas = len(scorenet.sigma_steps)
114
  self.flow = PatchFlow((num_sigmas, c, h, w), **flow_kwargs)
115
-
116
 
117
  self.flow = self.flow.to(device)
118
  self.scorenet = scorenet.to(device).requires_grad_(False)
@@ -265,7 +260,6 @@ def cmdline():
265
  type=str,
266
  required=True,
267
  )
268
-
269
  def cache_score_norms(preset, dataset_path, outdir):
270
  device = DEVICE
271
  dsobj = ImageFolderDataset(path=dataset_path, resolution=64)
@@ -353,7 +347,7 @@ def train_flow(dataset_path, preset, outdir, epochs, **flow_kwargs):
353
  val_ds = Subset(dsobj, range(train_len, train_len + val_len))
354
 
355
  trainiter = torch.utils.data.DataLoader(
356
- train_ds, batch_size=64, num_workers=4, prefetch_factor=2
357
  )
358
  testiter = torch.utils.data.DataLoader(
359
  val_ds, batch_size=128, num_workers=4, prefetch_factor=2
@@ -383,6 +377,9 @@ def train_flow(dataset_path, preset, outdir, epochs, **flow_kwargs):
383
  timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M")
384
  writer = SummaryWriter(f"{experiment_dir}/logs/{timestamp}")
385
 
 
 
 
386
  # totaliters = int(epochs * train_len)
387
  pbar = tqdm(range(epochs), desc="Train Loss: ? - Val Loss: ?")
388
  step = 0
@@ -398,8 +395,17 @@ def train_flow(dataset_path, preset, outdir, epochs, **flow_kwargs):
398
  val_loss = eval_step(scores, x)
399
 
400
  # Log details about model
401
- writer.add_graph(model.flow.flows, (torch.zeros(1, scores.shape[1], device=device),
402
- torch.zeros(1, model.flow.position_encoding.cached_penc.shape[-1], device=device)))
 
 
 
 
 
 
 
 
 
403
 
404
  train_loss = train_step(scores, x)
405
 
@@ -433,12 +439,14 @@ def train_flow(dataset_path, preset, outdir, epochs, **flow_kwargs):
433
  scores = model.scorenet(x)
434
  train_loss = train_step(scores, x)
435
  writer.add_scalar("loss/train", train_loss, step)
436
- pbar.set_description(
437
- f"(Tuning) Step: {step:d} - Loss: {train_loss:.3f}"
438
- )
439
  step += 1
440
 
 
441
  torch.save(model.flow.state_dict(), f"{experiment_dir}/flow.pt")
 
 
 
442
  writer.close()
443
 
444
 
 
1
  import datetime
2
+ import json
3
  import os
4
  import pickle
5
  from functools import partial
 
22
  from dataset import ImageFolderDataset
23
  from flowutils import PatchFlow, sanitize_locals
24
 
25
+ DEVICE: Literal["cuda", "cpu"] = "cpu"
26
  model_root = "https://nvlabs-fi-cdn.nvidia.com/edm2/posthoc-reconstructions"
27
 
28
  config_presets = {
 
57
  ):
58
  super().__init__()
59
 
60
+ self.config = sanitize_locals(locals(), ignore_keys="net")
61
+ self.config["EDMNet"] = dict(net.init_kwargs)
62
 
63
  self.use_fp16 = use_fp16
64
  self.sigma_min = sigma_min
 
101
 
102
 
103
  class ScoreFlow(torch.nn.Module):
104
+ def __init__(self, scorenet, device="cpu", **flow_kwargs):
 
 
 
 
 
105
  super().__init__()
106
 
107
  h = w = scorenet.net.img_resolution
108
  c = scorenet.net.img_channels
109
  num_sigmas = len(scorenet.sigma_steps)
110
  self.flow = PatchFlow((num_sigmas, c, h, w), **flow_kwargs)
 
111
 
112
  self.flow = self.flow.to(device)
113
  self.scorenet = scorenet.to(device).requires_grad_(False)
 
260
  type=str,
261
  required=True,
262
  )
 
263
  def cache_score_norms(preset, dataset_path, outdir):
264
  device = DEVICE
265
  dsobj = ImageFolderDataset(path=dataset_path, resolution=64)
 
347
  val_ds = Subset(dsobj, range(train_len, train_len + val_len))
348
 
349
  trainiter = torch.utils.data.DataLoader(
350
+ train_ds, batch_size=64, num_workers=4, prefetch_factor=2, shuffle=True
351
  )
352
  testiter = torch.utils.data.DataLoader(
353
  val_ds, batch_size=128, num_workers=4, prefetch_factor=2
 
377
  timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M")
378
  writer = SummaryWriter(f"{experiment_dir}/logs/{timestamp}")
379
 
380
+ with open(f"{experiment_dir}/logs/{timestamp}/config.json", "w") as f:
381
+ json.dump(model.config, f, sort_keys=True, indent=4)
382
+
383
  # totaliters = int(epochs * train_len)
384
  pbar = tqdm(range(epochs), desc="Train Loss: ? - Val Loss: ?")
385
  step = 0
 
395
  val_loss = eval_step(scores, x)
396
 
397
  # Log details about model
398
+ writer.add_graph(
399
+ model.flow.flows,
400
+ (
401
+ torch.zeros(1, scores.shape[1], device=device),
402
+ torch.zeros(
403
+ 1,
404
+ model.flow.position_encoding.cached_penc.shape[-1],
405
+ device=device,
406
+ ),
407
+ ),
408
+ )
409
 
410
  train_loss = train_step(scores, x)
411
 
 
439
  scores = model.scorenet(x)
440
  train_loss = train_step(scores, x)
441
  writer.add_scalar("loss/train", train_loss, step)
442
+ pbar.set_description(f"(Tuning) Step: {step:d} - Loss: {train_loss:.3f}")
 
 
443
  step += 1
444
 
445
+ # Save final model
446
  torch.save(model.flow.state_dict(), f"{experiment_dir}/flow.pt")
447
+ with open(f"{experiment_dir}/config.json", "w") as f:
448
+ json.dump(model.config, f, sort_keys=True, indent=4)
449
+
450
  writer.close()
451
 
452