ahsanMah commited on
Commit
95a02fd
·
1 Parent(s): b55c3d2

added support for loading models from HF hub

Browse files
Files changed (1) hide show
  1. app.py +38 -10
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from functools import cache
2
  from pickle import load
3
 
@@ -6,14 +7,42 @@ import matplotlib.pyplot as plt
6
  import numpy as np
7
  import PIL.Image as Image
8
  import torch
 
 
9
 
10
- from msma import ScoreFlow, config_presets
11
 
12
 
13
  @cache
14
- def load_model(modeldir, preset="edm2-img64-s-fid", device="cpu", outdir=None):
15
- model = ScoreFlow(preset, device=device)
16
- model.flow.load_state_dict(torch.load(f"{modeldir}/{preset}/flow.pt"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  return model
18
 
19
 
@@ -62,8 +91,9 @@ def plot_heatmap(img: Image, heatmap: np.array):
62
  return im
63
 
64
 
65
- def run_inference(input_img, preset="edm2-img64-s-fid", device="cuda"):
66
 
 
67
  # img = center_crop_imagenet(64, img)
68
  input_img = input_img.resize(size=(64, 64), resample=Image.Resampling.LANCZOS)
69
 
@@ -71,7 +101,8 @@ def run_inference(input_img, preset="edm2-img64-s-fid", device="cuda"):
71
  img = np.array(input_img)
72
  img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)
73
  img = img.float().to(device)
74
- model = load_model(modeldir="models", preset=preset, device=device)
 
75
  img_likelihood = model(img).cpu().numpy()
76
  # img_likelihood = model.scorenet(img).square().sum(1).sum(1).contiguous().float().cpu().unsqueeze(1).numpy()
77
  # print(img_likelihood.shape, img_likelihood.dtype)
@@ -100,10 +131,7 @@ demo = gr.Interface(
100
  gr.Image(label="Anomaly Heatmap", min_width=64),
101
  gr.Plot(label="Comparing to Imagenette"),
102
  ],
103
-
104
- examples=[
105
- ['goldfish.JPEG', "edm2-img64-s-fid"]
106
- ]
107
  )
108
 
109
  if __name__ == "__main__":
 
1
+ import json
2
  from functools import cache
3
  from pickle import load
4
 
 
7
  import numpy as np
8
  import PIL.Image as Image
9
  import torch
10
+ from huggingface_hub import hf_hub_download
11
+ from safetensors.torch import load_file
12
 
13
+ from msma import ScoreFlow, build_model_from_pickle, config_presets
14
 
15
 
16
  @cache
17
+ def load_model(modeldir, preset="edm2-img64-s-fid", device="cpu"):
18
+ model = ScoreFlow(preset, num_flows=8, device=device)
19
+ model.flow.load_state_dict(torch.load(f"{modeldir}/nb8/{preset}/flow.pt"))
20
+ return model
21
+
22
+ @cache
23
+ def load_model_from_hub(preset, device):
24
+ scorenet = build_model_from_pickle(preset)
25
+
26
+ hf_config = hf_hub_download(
27
+ repo_id="ahsanMah/localizing-edm",
28
+ subfolder=preset,
29
+ filename="config.json",
30
+ cache_dir="/tmp/",
31
+ )
32
+ with open(hf_config, "rb") as f:
33
+ model_params = json.load(f)
34
+ print("Loaded:", model_params)
35
+
36
+ hf_checkpoint = hf_hub_download(
37
+ repo_id="ahsanMah/localizing-edm",
38
+ subfolder=preset,
39
+ filename="model.safetensors",
40
+ cache_dir="/tmp/",
41
+ )
42
+
43
+ model = ScoreFlow(scorenet, device=device, **model_params['PatchFlow'])
44
+ model.load_state_dict(load_file(hf_checkpoint), strict=True)
45
+
46
  return model
47
 
48
 
 
91
  return im
92
 
93
 
94
+ def run_inference(input_img, preset="edm2-img64-s-fid"):
95
 
96
+ device = "cuda" if torch.cuda.is_available() else "cpu"
97
  # img = center_crop_imagenet(64, img)
98
  input_img = input_img.resize(size=(64, 64), resample=Image.Resampling.LANCZOS)
99
 
 
101
  img = np.array(input_img)
102
  img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)
103
  img = img.float().to(device)
104
+ # model = load_model(modeldir="models", preset=preset, device=device)
105
+ model = load_model_from_hub(preset=preset, device=device)
106
  img_likelihood = model(img).cpu().numpy()
107
  # img_likelihood = model.scorenet(img).square().sum(1).sum(1).contiguous().float().cpu().unsqueeze(1).numpy()
108
  # print(img_likelihood.shape, img_likelihood.dtype)
 
131
  gr.Image(label="Anomaly Heatmap", min_width=64),
132
  gr.Plot(label="Comparing to Imagenette"),
133
  ],
134
+ examples=[["goldfish.JPEG", "edm2-img64-s-fid"]],
 
 
 
135
  )
136
 
137
  if __name__ == "__main__":