ahsanMah committed on
Commit
a186356
·
1 Parent(s): 1b96548

using groups for command line options

Browse files
Files changed (1) hide show
  1. msma.py +142 -112
msma.py CHANGED
@@ -3,6 +3,7 @@ import os
3
  import pickle
4
  from functools import partial
5
  from pickle import dump, load
 
6
 
7
  import click
8
  import numpy as np
@@ -20,6 +21,7 @@ import dnnlib
20
  from dataset import ImageFolderDataset
21
  from flowutils import PatchFlow
22
 
 
23
  model_root = "https://nvlabs-fi-cdn.nvidia.com/edm2/posthoc-reconstructions"
24
 
25
  config_presets = {
@@ -100,6 +102,7 @@ class ScoreFlow(torch.nn.Module):
100
  self,
101
  preset,
102
  device="cpu",
 
103
  ):
104
  super().__init__()
105
 
@@ -107,7 +110,7 @@ class ScoreFlow(torch.nn.Module):
107
  h = w = scorenet.net.img_resolution
108
  c = scorenet.net.img_channels
109
  num_sigmas = len(scorenet.sigma_steps)
110
- self.flow = PatchFlow((num_sigmas, c, h, w))
111
 
112
  self.flow = self.flow.to(device)
113
  self.scorenet = scorenet.to(device).requires_grad_(False)
@@ -187,7 +190,78 @@ def compute_gmm_likelihood(x_score, gmmdir):
187
  return nll, percentile
188
 
189
 
190
- def cache_score_norms(preset, dataset_path, outdir, device="cpu"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
  dsobj = ImageFolderDataset(path=dataset_path, resolution=64)
192
  refimg, reflabel = dsobj[0]
193
  print(f"Loading dataset from {dataset_path}")
@@ -215,7 +289,40 @@ def cache_score_norms(preset, dataset_path, outdir, device="cpu"):
215
  print(f"Computed score norms for {score_norms.shape[0]} samples")
216
 
217
 
218
- def train_flow(dataset_path, preset, outdir, epochs=10, device="cuda"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
  dsobj = ImageFolderDataset(path=dataset_path, resolution=64)
220
  refimg, reflabel = dsobj[0]
221
  print(f"Loaded {len(dsobj)} samples from {dataset_path}")
@@ -238,7 +345,7 @@ def train_flow(dataset_path, preset, outdir, epochs=10, device="cuda"):
238
  val_ds, batch_size=128, num_workers=4, prefetch_factor=2
239
  )
240
 
241
- model = ScoreFlow(preset, device=device)
242
  opt = torch.optim.AdamW(model.flow.parameters(), lr=3e-4, weight_decay=1e-5)
243
  train_step = partial(
244
  PatchFlow.stochastic_step,
@@ -274,6 +381,10 @@ def train_flow(dataset_path, preset, outdir, epochs=10, device="cuda"):
274
  with torch.inference_mode():
275
  val_loss = eval_step(scores, x)
276
 
 
 
 
 
277
  train_loss = train_step(scores, x)
278
 
279
  if (step + 1) % 10 == 0:
@@ -297,117 +408,36 @@ def train_flow(dataset_path, preset, outdir, epochs=10, device="cuda"):
297
  )
298
  step += 1
299
 
300
- # torch.save(model.flow.state_dict(), f"{experiment_dir}/flow.pt")
301
- writer.close()
302
-
303
-
304
- @torch.inference_mode
305
- def test_runner(device="cpu"):
306
- # f = "doge.jpg"
307
- f = "goldfish.JPEG"
308
- image = (PIL.Image.open(f)).resize((64, 64), PIL.Image.Resampling.LANCZOS)
309
- image = np.array(image)
310
- image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1)
311
- x = torch.from_numpy(image).unsqueeze(0).to(device)
312
- model = build_model(device=device)
313
- scores = model(x)
314
-
315
- return scores
316
-
317
-
318
- def test_flow_runner(preset, device="cpu", load_weights=None):
319
- # f = "doge.jpg"
320
- f = "goldfish.JPEG"
321
- image = (PIL.Image.open(f)).resize((64, 64), PIL.Image.Resampling.LANCZOS)
322
- image = np.array(image)
323
- image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1)
324
- x = torch.from_numpy(image).unsqueeze(0).to(device)
325
-
326
- score_flow = ScoreFlow(preset, device=device)
327
-
328
- if load_weights is not None:
329
- score_flow.flow.load_state_dict(torch.load(load_weights))
330
-
331
- heatmap = score_flow(x)
332
- print(heatmap.shape)
333
-
334
- heatmap = score_flow(x).detach().cpu().numpy()
335
- heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min()) * 255
336
- im = PIL.Image.fromarray(heatmap[0, 0])
337
- im.convert("RGB").save(
338
- "heatmap.png",
339
- )
340
-
341
- return
342
 
 
 
343
 
344
- @click.command()
345
 
346
- # Main options.
347
- @click.option(
348
- "--run",
349
- help="Which function to run",
350
- type=click.Choice(
351
- ["cache-scores", "train-flow", "train-gmm"], case_sensitive=False
352
- ),
353
- )
354
- @click.option(
355
- "--outdir",
356
- help="Where to load/save the results",
357
- metavar="DIR",
358
- type=str,
359
- required=True,
360
- )
361
- @click.option(
362
- "--preset",
363
- help="Configuration preset",
364
- metavar="STR",
365
- type=str,
366
- default="edm2-img64-s-fid",
367
- show_default=True,
368
- )
369
- @click.option(
370
- "--data", help="Path to the dataset", metavar="ZIP|DIR", type=str, default=None
371
- )
372
- def cmdline(run, outdir, **opts):
373
- device = "cuda" if torch.cuda.is_available() else "cpu"
374
- preset = opts["preset"]
375
- dataset_path = opts["data"]
376
-
377
- if run in ["cache-scores", "train-flow"]:
378
- assert opts["data"] is not None, "Provide path to dataset"
379
-
380
- if run == "cache-scores":
381
- cache_score_norms(
382
- preset=preset, dataset_path=dataset_path, outdir=outdir, device=device
383
- )
384
-
385
- if run == "train-gmm":
386
- train_gmm(
387
- score_path=f"{outdir}/{preset}/imagenette_score_norms.pt",
388
- outdir=f"{outdir}/{preset}",
389
- grid_search=True,
390
- )
391
-
392
- if run == "train-flow":
393
- train_flow(dataset_path, outdir=outdir, preset=preset, device=device)
394
- test_flow_runner(preset, device=device, load_weights=f"{outdir}/{preset}/flow.pt")
395
-
396
- # train_flow(imagenette_path, preset, device)
397
-
398
- # cache_score_norms(
399
- # preset=preset,
400
- # dataset_path="/GROND_STOR/amahmood/datasets/img64/",
401
- # device="cuda",
402
- # )
403
- # train_gmm(
404
- # f"out/msma/{preset}_imagenette_score_norms.pt", outdir=f"out/msma/{preset}"
405
- # )
406
- # s = test_runner(device=device)
407
- # s = s.square().sum(dim=(2, 3, 4)) ** 0.5
408
- # s = s.to("cpu").numpy()
409
- # nll, pct = compute_gmm_likelihood(s, gmmdir=f"out/msma/{preset}/")
410
- # print(f"Anomaly score for image: {nll[0]:.3f} @ {pct*100:.2f} percentile")
411
 
412
 
413
  if __name__ == "__main__":
 
3
  import pickle
4
  from functools import partial
5
  from pickle import dump, load
6
+ from typing import Literal
7
 
8
  import click
9
  import numpy as np
 
21
  from dataset import ImageFolderDataset
22
  from flowutils import PatchFlow
23
 
24
+ DEVICE: Literal["cuda", "cpu"] = 'cpu'
25
  model_root = "https://nvlabs-fi-cdn.nvidia.com/edm2/posthoc-reconstructions"
26
 
27
  config_presets = {
 
102
  self,
103
  preset,
104
  device="cpu",
105
+ **flow_kwargs
106
  ):
107
  super().__init__()
108
 
 
110
  h = w = scorenet.net.img_resolution
111
  c = scorenet.net.img_channels
112
  num_sigmas = len(scorenet.sigma_steps)
113
+ self.flow = PatchFlow((num_sigmas, c, h, w), **flow_kwargs)
114
 
115
  self.flow = self.flow.to(device)
116
  self.scorenet = scorenet.to(device).requires_grad_(False)
 
190
  return nll, percentile
191
 
192
 
193
+ @torch.inference_mode
194
+ def test_runner(device="cpu"):
195
+ # f = "doge.jpg"
196
+ f = "goldfish.JPEG"
197
+ image = (PIL.Image.open(f)).resize((64, 64), PIL.Image.Resampling.LANCZOS)
198
+ image = np.array(image)
199
+ image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1)
200
+ x = torch.from_numpy(image).unsqueeze(0).to(device)
201
+ model = build_model(device=device)
202
+ scores = model(x)
203
+
204
+ return scores
205
+
206
+
207
+ def test_flow_runner(preset, device="cpu", load_weights=None):
208
+ # f = "doge.jpg"
209
+ f = "goldfish.JPEG"
210
+ image = (PIL.Image.open(f)).resize((64, 64), PIL.Image.Resampling.LANCZOS)
211
+ image = np.array(image)
212
+ image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1)
213
+ x = torch.from_numpy(image).unsqueeze(0).to(device)
214
+
215
+ score_flow = ScoreFlow(preset, device=device)
216
+
217
+ if load_weights is not None:
218
+ score_flow.flow.load_state_dict(torch.load(load_weights))
219
+
220
+ heatmap = score_flow(x)
221
+ print(heatmap.shape)
222
+
223
+ heatmap = score_flow(x).detach().cpu().numpy()
224
+ heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min()) * 255
225
+ im = PIL.Image.fromarray(heatmap[0, 0])
226
+ im.convert("RGB").save(
227
+ "heatmap.png",
228
+ )
229
+
230
+ return
231
+
232
+
233
+ @click.group()
234
+ def cmdline():
235
+ global DEVICE
236
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
237
+
238
+
239
+ @cmdline.command(name="cache-scores")
240
+ @click.option(
241
+ "--preset",
242
+ help="Configuration preset",
243
+ metavar="STR",
244
+ type=str,
245
+ default="edm2-img64-s-fid",
246
+ show_default=True,
247
+ )
248
+ @click.option(
249
+ "--dataset_path",
250
+ help="Path to the dataset",
251
+ metavar="ZIP|DIR",
252
+ type=str,
253
+ default=None,
254
+ )
255
+ @click.option(
256
+ "--outdir",
257
+ help="Where to load/save the results",
258
+ metavar="DIR",
259
+ type=str,
260
+ required=True,
261
+ )
262
+
263
+ def cache_score_norms(preset, dataset_path, outdir):
264
+ device = DEVICE
265
  dsobj = ImageFolderDataset(path=dataset_path, resolution=64)
266
  refimg, reflabel = dsobj[0]
267
  print(f"Loading dataset from {dataset_path}")
 
289
  print(f"Computed score norms for {score_norms.shape[0]} samples")
290
 
291
 
292
+ @cmdline.command(name="train-flow")
293
+ @click.option(
294
+ "--dataset_path",
295
+ help="Path to the dataset",
296
+ metavar="ZIP|DIR",
297
+ type=str,
298
+ default=None,
299
+ )
300
+ @click.option(
301
+ "--outdir",
302
+ help="Where to load/save the results",
303
+ metavar="DIR",
304
+ type=str,
305
+ required=True,
306
+ )
307
+ @click.option(
308
+ "--preset",
309
+ help="Configuration preset",
310
+ metavar="STR",
311
+ type=str,
312
+ default="edm2-img64-s-fid",
313
+ show_default=True,
314
+ )
315
+ @click.option(
316
+ "--num_flows",
317
+ help="Number of normalizing flow functions in the PatchFlow model",
318
+ metavar="INT",
319
+ type=int,
320
+ default=4,
321
+ show_default=True,
322
+ )
323
+ def train_flow(dataset_path, preset, outdir, epochs=10, **flow_kwargs):
324
+ print("using device:", DEVICE)
325
+ device = DEVICE
326
  dsobj = ImageFolderDataset(path=dataset_path, resolution=64)
327
  refimg, reflabel = dsobj[0]
328
  print(f"Loaded {len(dsobj)} samples from {dataset_path}")
 
345
  val_ds, batch_size=128, num_workers=4, prefetch_factor=2
346
  )
347
 
348
+ model = ScoreFlow(preset, device=device, **flow_kwargs)
349
  opt = torch.optim.AdamW(model.flow.parameters(), lr=3e-4, weight_decay=1e-5)
350
  train_step = partial(
351
  PatchFlow.stochastic_step,
 
381
  with torch.inference_mode():
382
  val_loss = eval_step(scores, x)
383
 
384
+ # Log details about model
385
+ writer.add_graph(model.flow.flows, (torch.zeros(1, scores.shape[1], device=device),
386
+ torch.zeros(1, model.flow.position_encoding.cached_penc.shape[-1], device=device)))
387
+
388
  train_loss = train_step(scores, x)
389
 
390
  if (step + 1) % 10 == 0:
 
408
  )
409
  step += 1
410
 
411
+ # Squeeze the juice
412
+ best_ckpt = torch.load(f"{experiment_dir}/flow.pt")
413
+ model.flow.load_state_dict(best_ckpt)
414
+ for i, (x, _) in enumerate(testiter):
415
+ x = x.to(device)
416
+ scores = model.scorenet(x)
417
+ train_loss = train_step(scores, x)
418
+ writer.add_scalar("loss/train", train_loss, step)
419
+ pbar.set_description(
420
+ f"(Tuning) Step: {step:d} - Train: {train_loss:.3f} - Val: {val_loss:.3f}"
421
+ )
422
+ step += 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
423
 
424
+ torch.save(model.flow.state_dict(), f"{experiment_dir}/flow.pt")
425
+ writer.close()
426
 
 
427
 
428
+ # cache_score_norms(
429
+ # preset=preset,
430
+ # dataset_path="/GROND_STOR/amahmood/datasets/img64/",
431
+ # device="cuda",
432
+ # )
433
+ # train_gmm(
434
+ # f"out/msma/{preset}_imagenette_score_norms.pt", outdir=f"out/msma/{preset}"
435
+ # )
436
+ # s = test_runner(device=device)
437
+ # s = s.square().sum(dim=(2, 3, 4)) ** 0.5
438
+ # s = s.to("cpu").numpy()
439
+ # nll, pct = compute_gmm_likelihood(s, gmmdir=f"out/msma/{preset}/")
440
+ # print(f"Anomaly score for image: {nll[0]:.3f} @ {pct*100:.2f} percentile")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
441
 
442
 
443
  if __name__ == "__main__":