Spaces:
Runtime error
Runtime error
saving model configs
Browse files- flowutils.py +1 -1
- msma.py +25 -17
flowutils.py
CHANGED
@@ -221,7 +221,7 @@ class PatchFlow(torch.nn.Module):
|
|
221 |
):
|
222 |
super().__init__()
|
223 |
|
224 |
-
self.config = sanitize_locals(locals(), ignore_keys=input_size)
|
225 |
|
226 |
num_sigmas, c, h, w = input_size
|
227 |
self.local_pooler = SpatialNormer(
|
|
|
221 |
):
|
222 |
super().__init__()
|
223 |
|
224 |
+
self.config = sanitize_locals(locals(), ignore_keys="input_size")
|
225 |
|
226 |
num_sigmas, c, h, w = input_size
|
227 |
self.local_pooler = SpatialNormer(
|
msma.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import datetime
|
|
|
2 |
import os
|
3 |
import pickle
|
4 |
from functools import partial
|
@@ -21,7 +22,7 @@ import dnnlib
|
|
21 |
from dataset import ImageFolderDataset
|
22 |
from flowutils import PatchFlow, sanitize_locals
|
23 |
|
24 |
-
DEVICE: Literal["cuda", "cpu"] =
|
25 |
model_root = "https://nvlabs-fi-cdn.nvidia.com/edm2/posthoc-reconstructions"
|
26 |
|
27 |
config_presets = {
|
@@ -56,8 +57,8 @@ class EDMScorer(torch.nn.Module):
|
|
56 |
):
|
57 |
super().__init__()
|
58 |
|
59 |
-
self.config = sanitize_locals(locals(), ignore_keys=
|
60 |
-
self.config[
|
61 |
|
62 |
self.use_fp16 = use_fp16
|
63 |
self.sigma_min = sigma_min
|
@@ -100,19 +101,13 @@ class EDMScorer(torch.nn.Module):
|
|
100 |
|
101 |
|
102 |
class ScoreFlow(torch.nn.Module):
|
103 |
-
def __init__(
|
104 |
-
self,
|
105 |
-
scorenet,
|
106 |
-
device="cpu",
|
107 |
-
**flow_kwargs
|
108 |
-
):
|
109 |
super().__init__()
|
110 |
|
111 |
h = w = scorenet.net.img_resolution
|
112 |
c = scorenet.net.img_channels
|
113 |
num_sigmas = len(scorenet.sigma_steps)
|
114 |
self.flow = PatchFlow((num_sigmas, c, h, w), **flow_kwargs)
|
115 |
-
|
116 |
|
117 |
self.flow = self.flow.to(device)
|
118 |
self.scorenet = scorenet.to(device).requires_grad_(False)
|
@@ -265,7 +260,6 @@ def cmdline():
|
|
265 |
type=str,
|
266 |
required=True,
|
267 |
)
|
268 |
-
|
269 |
def cache_score_norms(preset, dataset_path, outdir):
|
270 |
device = DEVICE
|
271 |
dsobj = ImageFolderDataset(path=dataset_path, resolution=64)
|
@@ -353,7 +347,7 @@ def train_flow(dataset_path, preset, outdir, epochs, **flow_kwargs):
|
|
353 |
val_ds = Subset(dsobj, range(train_len, train_len + val_len))
|
354 |
|
355 |
trainiter = torch.utils.data.DataLoader(
|
356 |
-
train_ds, batch_size=64, num_workers=4, prefetch_factor=2
|
357 |
)
|
358 |
testiter = torch.utils.data.DataLoader(
|
359 |
val_ds, batch_size=128, num_workers=4, prefetch_factor=2
|
@@ -383,6 +377,9 @@ def train_flow(dataset_path, preset, outdir, epochs, **flow_kwargs):
|
|
383 |
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M")
|
384 |
writer = SummaryWriter(f"{experiment_dir}/logs/{timestamp}")
|
385 |
|
|
|
|
|
|
|
386 |
# totaliters = int(epochs * train_len)
|
387 |
pbar = tqdm(range(epochs), desc="Train Loss: ? - Val Loss: ?")
|
388 |
step = 0
|
@@ -398,8 +395,17 @@ def train_flow(dataset_path, preset, outdir, epochs, **flow_kwargs):
|
|
398 |
val_loss = eval_step(scores, x)
|
399 |
|
400 |
# Log details about model
|
401 |
-
writer.add_graph(
|
402 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
403 |
|
404 |
train_loss = train_step(scores, x)
|
405 |
|
@@ -433,12 +439,14 @@ def train_flow(dataset_path, preset, outdir, epochs, **flow_kwargs):
|
|
433 |
scores = model.scorenet(x)
|
434 |
train_loss = train_step(scores, x)
|
435 |
writer.add_scalar("loss/train", train_loss, step)
|
436 |
-
pbar.set_description(
|
437 |
-
f"(Tuning) Step: {step:d} - Loss: {train_loss:.3f}"
|
438 |
-
)
|
439 |
step += 1
|
440 |
|
|
|
441 |
torch.save(model.flow.state_dict(), f"{experiment_dir}/flow.pt")
|
|
|
|
|
|
|
442 |
writer.close()
|
443 |
|
444 |
|
|
|
1 |
import datetime
|
2 |
+
import json
|
3 |
import os
|
4 |
import pickle
|
5 |
from functools import partial
|
|
|
22 |
from dataset import ImageFolderDataset
|
23 |
from flowutils import PatchFlow, sanitize_locals
|
24 |
|
25 |
+
DEVICE: Literal["cuda", "cpu"] = "cpu"
|
26 |
model_root = "https://nvlabs-fi-cdn.nvidia.com/edm2/posthoc-reconstructions"
|
27 |
|
28 |
config_presets = {
|
|
|
57 |
):
|
58 |
super().__init__()
|
59 |
|
60 |
+
self.config = sanitize_locals(locals(), ignore_keys="net")
|
61 |
+
self.config["EDMNet"] = dict(net.init_kwargs)
|
62 |
|
63 |
self.use_fp16 = use_fp16
|
64 |
self.sigma_min = sigma_min
|
|
|
101 |
|
102 |
|
103 |
class ScoreFlow(torch.nn.Module):
|
104 |
+
def __init__(self, scorenet, device="cpu", **flow_kwargs):
|
|
|
|
|
|
|
|
|
|
|
105 |
super().__init__()
|
106 |
|
107 |
h = w = scorenet.net.img_resolution
|
108 |
c = scorenet.net.img_channels
|
109 |
num_sigmas = len(scorenet.sigma_steps)
|
110 |
self.flow = PatchFlow((num_sigmas, c, h, w), **flow_kwargs)
|
|
|
111 |
|
112 |
self.flow = self.flow.to(device)
|
113 |
self.scorenet = scorenet.to(device).requires_grad_(False)
|
|
|
260 |
type=str,
|
261 |
required=True,
|
262 |
)
|
|
|
263 |
def cache_score_norms(preset, dataset_path, outdir):
|
264 |
device = DEVICE
|
265 |
dsobj = ImageFolderDataset(path=dataset_path, resolution=64)
|
|
|
347 |
val_ds = Subset(dsobj, range(train_len, train_len + val_len))
|
348 |
|
349 |
trainiter = torch.utils.data.DataLoader(
|
350 |
+
train_ds, batch_size=64, num_workers=4, prefetch_factor=2, shuffle=True
|
351 |
)
|
352 |
testiter = torch.utils.data.DataLoader(
|
353 |
val_ds, batch_size=128, num_workers=4, prefetch_factor=2
|
|
|
377 |
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M")
|
378 |
writer = SummaryWriter(f"{experiment_dir}/logs/{timestamp}")
|
379 |
|
380 |
+
with open(f"{experiment_dir}/logs/{timestamp}/config.json", "w") as f:
|
381 |
+
json.dump(model.config, f, sort_keys=True, indent=4)
|
382 |
+
|
383 |
# totaliters = int(epochs * train_len)
|
384 |
pbar = tqdm(range(epochs), desc="Train Loss: ? - Val Loss: ?")
|
385 |
step = 0
|
|
|
395 |
val_loss = eval_step(scores, x)
|
396 |
|
397 |
# Log details about model
|
398 |
+
writer.add_graph(
|
399 |
+
model.flow.flows,
|
400 |
+
(
|
401 |
+
torch.zeros(1, scores.shape[1], device=device),
|
402 |
+
torch.zeros(
|
403 |
+
1,
|
404 |
+
model.flow.position_encoding.cached_penc.shape[-1],
|
405 |
+
device=device,
|
406 |
+
),
|
407 |
+
),
|
408 |
+
)
|
409 |
|
410 |
train_loss = train_step(scores, x)
|
411 |
|
|
|
439 |
scores = model.scorenet(x)
|
440 |
train_loss = train_step(scores, x)
|
441 |
writer.add_scalar("loss/train", train_loss, step)
|
442 |
+
pbar.set_description(f"(Tuning) Step: {step:d} - Loss: {train_loss:.3f}")
|
|
|
|
|
443 |
step += 1
|
444 |
|
445 |
+
# Save final model
|
446 |
torch.save(model.flow.state_dict(), f"{experiment_dir}/flow.pt")
|
447 |
+
with open(f"{experiment_dir}/config.json", "w") as f:
|
448 |
+
json.dump(model.config, f, sort_keys=True, indent=4)
|
449 |
+
|
450 |
writer.close()
|
451 |
|
452 |
|