ahsanMah commited on
Commit
cd21582
·
1 Parent(s): 65245db

format changes

Browse files
Files changed (1) hide show
  1. app.py +39 -23
app.py CHANGED
@@ -11,17 +11,19 @@ from msma import ScoreFlow, config_presets
11
 
12
 
13
  @cache
14
- def load_model(modeldir, preset="edm2-img64-s-fid", device='cpu', outdir=None):
15
  model = ScoreFlow(preset, device=device)
16
  model.flow.load_state_dict(torch.load(f"{modeldir}/{preset}/flow.pt"))
17
  return model
18
 
 
19
  @cache
20
  def load_reference_scores(model_dir):
21
  with np.load(f"{model_dir}/refscores.npz", "rb") as f:
22
  ref_nll = f["arr_0"]
23
  return ref_nll
24
 
 
25
  def compute_gmm_likelihood(x_score, model_dir):
26
  with open(f"{model_dir}/gmm.pkl", "rb") as f:
27
  clf = load(f)
@@ -32,47 +34,53 @@ def compute_gmm_likelihood(x_score, model_dir):
32
 
33
  return nll, percentile, ref_nll
34
 
 
35
  def plot_against_reference(nll, ref_nll):
36
  fig, ax = plt.subplots()
37
  ax.hist(ref_nll, label="Reference Scores")
38
- ax.axvline(nll, label='Image Score', c='red', ls="--")
39
  plt.legend()
40
  fig.tight_layout()
41
  return fig
42
 
 
43
  def plot_heatmap(img: Image, heatmap: np.array):
44
  fig, ax = plt.subplots()
45
  cmap = plt.get_cmap("gist_heat")
46
- h = heatmap[0,0].copy()
47
- qmin, qmax = np.quantile(h, 0.5), np.quantile(h, 0.999)
48
  h = np.clip(h, a_min=qmin, a_max=qmax)
49
- h = (h-h.min()) / (h.max() - h.min())
50
- h = cmap(h, bytes=True)[:,:,:3]
51
  h = Image.fromarray(h).resize(img.size, resample=Image.Resampling.BILINEAR)
52
  im = Image.blend(img, h, alpha=0.6)
53
- im = ax.imshow(np.array(im))
54
- # fig.colorbar(im)
55
- # plt.grid(False)
56
- # plt.axis("off")
57
- fig.tight_layout()
58
- return fig
59
 
60
- def run_inference(input_img, preset="edm2-img64-s-fid", device="cuda"):
 
61
 
62
  # img = center_crop_imagenet(64, img)
63
  input_img = input_img.resize(size=(64, 64), resample=Image.Resampling.LANCZOS)
64
 
65
  with torch.inference_mode():
66
  img = np.array(input_img)
67
- img = torch.from_numpy(img).permute(2,0,1).unsqueeze(0)
68
  img = img.float().to(device)
69
- model = load_model(modeldir='models', preset=preset, device=device)
70
  img_likelihood = model(img).cpu().numpy()
71
-
72
- img = torch.nn.functional.interpolate(img, size=64, mode='bilinear')
 
73
  x = model.scorenet(img)
74
  x = x.square().sum(dim=(2, 3, 4)) ** 0.5
75
- nll, pct, ref_nll = compute_gmm_likelihood(x.cpu(), model_dir=f"models/{preset}")
 
 
76
 
77
  outstr = f"Anomaly score: {nll:.3f} / {pct:.2f} percentile"
78
  histplot = plot_against_reference(nll, ref_nll)
@@ -83,11 +91,19 @@ def run_inference(input_img, preset="edm2-img64-s-fid", device="cuda"):
83
 
84
  demo = gr.Interface(
85
  fn=run_inference,
86
- inputs=[gr.Image(type='pil', label="Input Image")],
87
- outputs=["text",
88
- gr.Plot(label="Anomaly Heatmap"),
89
- gr.Plot(label="Comparing to Imagenette"),
90
- ],
 
 
 
 
 
 
 
 
91
  )
92
 
93
  if __name__ == "__main__":
 
11
 
12
 
13
  @cache
14
+ def load_model(modeldir, preset="edm2-img64-s-fid", device="cpu", outdir=None):
15
  model = ScoreFlow(preset, device=device)
16
  model.flow.load_state_dict(torch.load(f"{modeldir}/{preset}/flow.pt"))
17
  return model
18
 
19
+
20
  @cache
21
  def load_reference_scores(model_dir):
22
  with np.load(f"{model_dir}/refscores.npz", "rb") as f:
23
  ref_nll = f["arr_0"]
24
  return ref_nll
25
 
26
+
27
  def compute_gmm_likelihood(x_score, model_dir):
28
  with open(f"{model_dir}/gmm.pkl", "rb") as f:
29
  clf = load(f)
 
34
 
35
  return nll, percentile, ref_nll
36
 
37
+
38
  def plot_against_reference(nll, ref_nll):
39
  fig, ax = plt.subplots()
40
  ax.hist(ref_nll, label="Reference Scores")
41
+ ax.axvline(nll, label="Image Score", c="red", ls="--")
42
  plt.legend()
43
  fig.tight_layout()
44
  return fig
45
 
46
+
47
  def plot_heatmap(img: Image, heatmap: np.array):
48
  fig, ax = plt.subplots()
49
  cmap = plt.get_cmap("gist_heat")
50
+ h = -heatmap[0, 0].copy()
51
+ qmin, qmax = np.quantile(h, 0.8), np.quantile(h, 0.999)
52
  h = np.clip(h, a_min=qmin, a_max=qmax)
53
+ h = (h - h.min()) / (h.max() - h.min())
54
+ h = cmap(h, bytes=True)[:, :, :3]
55
  h = Image.fromarray(h).resize(img.size, resample=Image.Resampling.BILINEAR)
56
  im = Image.blend(img, h, alpha=0.6)
57
+ # im = ax.imshow(np.array(im))
58
+ # # fig.colorbar(im)
59
+ # # plt.grid(False)
60
+ # # plt.axis("off")
61
+ # fig.tight_layout()
62
+ return im
63
 
64
+
65
+ def run_inference(input_img, preset="edm2-img64-s-fid", device="cuda"):
66
 
67
  # img = center_crop_imagenet(64, img)
68
  input_img = input_img.resize(size=(64, 64), resample=Image.Resampling.LANCZOS)
69
 
70
  with torch.inference_mode():
71
  img = np.array(input_img)
72
+ img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)
73
  img = img.float().to(device)
74
+ model = load_model(modeldir="models", preset=preset, device=device)
75
  img_likelihood = model(img).cpu().numpy()
76
+ # img_likelihood = model.scorenet(img).square().sum(1).sum(1).contiguous().float().cpu().unsqueeze(1).numpy()
77
+ # print(img_likelihood.shape, img_likelihood.dtype)
78
+ img = torch.nn.functional.interpolate(img, size=64, mode="bilinear")
79
  x = model.scorenet(img)
80
  x = x.square().sum(dim=(2, 3, 4)) ** 0.5
81
+ nll, pct, ref_nll = compute_gmm_likelihood(
82
+ x.cpu(), model_dir=f"models/{preset}"
83
+ )
84
 
85
  outstr = f"Anomaly score: {nll:.3f} / {pct:.2f} percentile"
86
  histplot = plot_against_reference(nll, ref_nll)
 
91
 
92
  demo = gr.Interface(
93
  fn=run_inference,
94
+ inputs=[
95
+ gr.Image(type="pil", label="Input Image"),
96
+ gr.Dropdown(choices=config_presets.keys(), label="Score Model"),
97
+ ],
98
+ outputs=[
99
+ "text",
100
+ gr.Image(label="Anomaly Heatmap", min_width=64),
101
+ gr.Plot(label="Comparing to Imagenette"),
102
+ ],
103
+
104
+ examples=[
105
+ ['goldfish.JPEG', "edm2-img64-s-fid"]
106
+ ]
107
  )
108
 
109
  if __name__ == "__main__":