ahsanMah commited on
Commit
d71875f
·
1 Parent(s): bf573cf

plotting a blended version of the heatmap

Browse files
Files changed (1) hide show
  1. app.py +25 -13
app.py CHANGED
@@ -4,6 +4,7 @@ from pickle import load
4
  import gradio as gr
5
  import matplotlib.pyplot as plt
6
  import numpy as np
 
7
  import torch
8
 
9
  from msma import ScoreFlow, config_presets
@@ -39,39 +40,50 @@ def plot_against_reference(nll, ref_nll):
39
  fig.tight_layout()
40
  return fig
41
 
42
-
43
- def plot_heatmap(heatmap):
44
  fig, ax = plt.subplots()
45
- im = heatmap[0,0]
46
- ax.imshow(im, cmap='gist_heat')
 
 
 
 
 
 
 
 
 
 
47
  fig.tight_layout()
48
  return fig
49
 
50
- # def compute_scores
51
-
52
 
53
- def run_inference(img, preset="edm2-img64-s-fid", device="cuda"):
 
54
 
55
  with torch.inference_mode():
 
56
  img = torch.from_numpy(img).permute(2,0,1).unsqueeze(0)
57
- img = torch.nn.functional.interpolate(img, size=64, mode='bilinear')
58
- img = img.to(device)
59
  model = load_model(modeldir='models', preset=preset, device=device)
 
 
 
60
  x = model.scorenet(img)
61
  x = x.square().sum(dim=(2, 3, 4)) ** 0.5
62
- img_likelihood = model(img).cpu().numpy()
63
  nll, pct, ref_nll = compute_gmm_likelihood(x.cpu(), model_dir=f"models/{preset}")
64
-
65
  outstr = f"Anomaly score: {nll:.3f} / {pct:.2f} percentile"
66
  histplot = plot_against_reference(nll, ref_nll)
67
- heatmapplot = plot_heatmap(img_likelihood)
68
 
69
  return outstr, heatmapplot, histplot
70
 
71
 
72
  demo = gr.Interface(
73
  fn=run_inference,
74
- inputs=["image"],
75
  outputs=["text",
76
  gr.Plot(label="Anomaly Heatmap"),
77
  gr.Plot(label="Comparing to Imagenette"),
 
4
  import gradio as gr
5
  import matplotlib.pyplot as plt
6
  import numpy as np
7
+ import PIL.Image as Image
8
  import torch
9
 
10
  from msma import ScoreFlow, config_presets
 
40
  fig.tight_layout()
41
  return fig
42
 
43
+ def plot_heatmap(img: Image, heatmap: np.array):
 
44
  fig, ax = plt.subplots()
45
+ cmap = plt.get_cmap("gist_heat")
46
+ h = heatmap[0,0].copy()
47
+ qmin, qmax = np.quantile(h, 0.5), np.quantile(h, 0.999)
48
+ h = np.clip(h, a_min=qmin, a_max=qmax)
49
+ h = (h-h.min()) / (h.max() - h.min())
50
+ h = cmap(h, bytes=True)[:,:,:3]
51
+ h = Image.fromarray(h).resize(img.size, resample=Image.Resampling.BILINEAR)
52
+ im = Image.blend(img, h, alpha=0.6)
53
+ im = ax.imshow(np.array(im))
54
+ # fig.colorbar(im)
55
+ # plt.grid(False)
56
+ # plt.axis("off")
57
  fig.tight_layout()
58
  return fig
59
 
60
+ def run_inference(input_img, preset="edm2-img64-s-fid", device="cuda"):
 
61
 
62
+ # img = center_crop_imagenet(64, img)
63
+ input_img = input_img.resize(size=(64, 64), resample=Image.Resampling.LANCZOS)
64
 
65
  with torch.inference_mode():
66
+ img = np.array(input_img)
67
  img = torch.from_numpy(img).permute(2,0,1).unsqueeze(0)
68
+ img = img.float().to(device)
 
69
  model = load_model(modeldir='models', preset=preset, device=device)
70
+ img_likelihood = model(img).cpu().numpy()
71
+
72
+ img = torch.nn.functional.interpolate(img, size=64, mode='bilinear')
73
  x = model.scorenet(img)
74
  x = x.square().sum(dim=(2, 3, 4)) ** 0.5
 
75
  nll, pct, ref_nll = compute_gmm_likelihood(x.cpu(), model_dir=f"models/{preset}")
76
+
77
  outstr = f"Anomaly score: {nll:.3f} / {pct:.2f} percentile"
78
  histplot = plot_against_reference(nll, ref_nll)
79
+ heatmapplot = plot_heatmap(input_img, img_likelihood)
80
 
81
  return outstr, heatmapplot, histplot
82
 
83
 
84
  demo = gr.Interface(
85
  fn=run_inference,
86
+ inputs=[gr.Image(type='pil', label="Input Image")],
87
  outputs=["text",
88
  gr.Plot(label="Anomaly Heatmap"),
89
  gr.Plot(label="Comparing to Imagenette"),