Spaces:
Runtime error
Runtime error
plotting a blended version of the heatmap
Browse files
app.py
CHANGED
@@ -4,6 +4,7 @@ from pickle import load
|
|
4 |
import gradio as gr
|
5 |
import matplotlib.pyplot as plt
|
6 |
import numpy as np
|
|
|
7 |
import torch
|
8 |
|
9 |
from msma import ScoreFlow, config_presets
|
@@ -39,39 +40,50 @@ def plot_against_reference(nll, ref_nll):
|
|
39 |
fig.tight_layout()
|
40 |
return fig
|
41 |
|
42 |
-
|
43 |
-
def plot_heatmap(heatmap):
|
44 |
fig, ax = plt.subplots()
|
45 |
-
|
46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
fig.tight_layout()
|
48 |
return fig
|
49 |
|
50 |
-
|
51 |
-
|
52 |
|
53 |
-
|
|
|
54 |
|
55 |
with torch.inference_mode():
|
|
|
56 |
img = torch.from_numpy(img).permute(2,0,1).unsqueeze(0)
|
57 |
-
img =
|
58 |
-
img = img.to(device)
|
59 |
model = load_model(modeldir='models', preset=preset, device=device)
|
|
|
|
|
|
|
60 |
x = model.scorenet(img)
|
61 |
x = x.square().sum(dim=(2, 3, 4)) ** 0.5
|
62 |
-
img_likelihood = model(img).cpu().numpy()
|
63 |
nll, pct, ref_nll = compute_gmm_likelihood(x.cpu(), model_dir=f"models/{preset}")
|
64 |
-
|
65 |
outstr = f"Anomaly score: {nll:.3f} / {pct:.2f} percentile"
|
66 |
histplot = plot_against_reference(nll, ref_nll)
|
67 |
-
heatmapplot = plot_heatmap(img_likelihood)
|
68 |
|
69 |
return outstr, heatmapplot, histplot
|
70 |
|
71 |
|
72 |
demo = gr.Interface(
|
73 |
fn=run_inference,
|
74 |
-
inputs=["
|
75 |
outputs=["text",
|
76 |
gr.Plot(label="Anomaly Heatmap"),
|
77 |
gr.Plot(label="Comparing to Imagenette"),
|
|
|
4 |
import gradio as gr
|
5 |
import matplotlib.pyplot as plt
|
6 |
import numpy as np
|
7 |
+
import PIL.Image as Image
|
8 |
import torch
|
9 |
|
10 |
from msma import ScoreFlow, config_presets
|
|
|
40 |
fig.tight_layout()
|
41 |
return fig
|
42 |
|
43 |
+
def plot_heatmap(img: Image, heatmap: np.array):
|
|
|
44 |
fig, ax = plt.subplots()
|
45 |
+
cmap = plt.get_cmap("gist_heat")
|
46 |
+
h = heatmap[0,0].copy()
|
47 |
+
qmin, qmax = np.quantile(h, 0.5), np.quantile(h, 0.999)
|
48 |
+
h = np.clip(h, a_min=qmin, a_max=qmax)
|
49 |
+
h = (h-h.min()) / (h.max() - h.min())
|
50 |
+
h = cmap(h, bytes=True)[:,:,:3]
|
51 |
+
h = Image.fromarray(h).resize(img.size, resample=Image.Resampling.BILINEAR)
|
52 |
+
im = Image.blend(img, h, alpha=0.6)
|
53 |
+
im = ax.imshow(np.array(im))
|
54 |
+
# fig.colorbar(im)
|
55 |
+
# plt.grid(False)
|
56 |
+
# plt.axis("off")
|
57 |
fig.tight_layout()
|
58 |
return fig
|
59 |
|
60 |
+
def run_inference(input_img, preset="edm2-img64-s-fid", device="cuda"):
|
|
|
61 |
|
62 |
+
# img = center_crop_imagenet(64, img)
|
63 |
+
input_img = input_img.resize(size=(64, 64), resample=Image.Resampling.LANCZOS)
|
64 |
|
65 |
with torch.inference_mode():
|
66 |
+
img = np.array(input_img)
|
67 |
img = torch.from_numpy(img).permute(2,0,1).unsqueeze(0)
|
68 |
+
img = img.float().to(device)
|
|
|
69 |
model = load_model(modeldir='models', preset=preset, device=device)
|
70 |
+
img_likelihood = model(img).cpu().numpy()
|
71 |
+
|
72 |
+
img = torch.nn.functional.interpolate(img, size=64, mode='bilinear')
|
73 |
x = model.scorenet(img)
|
74 |
x = x.square().sum(dim=(2, 3, 4)) ** 0.5
|
|
|
75 |
nll, pct, ref_nll = compute_gmm_likelihood(x.cpu(), model_dir=f"models/{preset}")
|
76 |
+
|
77 |
outstr = f"Anomaly score: {nll:.3f} / {pct:.2f} percentile"
|
78 |
histplot = plot_against_reference(nll, ref_nll)
|
79 |
+
heatmapplot = plot_heatmap(input_img, img_likelihood)
|
80 |
|
81 |
return outstr, heatmapplot, histplot
|
82 |
|
83 |
|
84 |
demo = gr.Interface(
|
85 |
fn=run_inference,
|
86 |
+
inputs=[gr.Image(type='pil', label="Input Image")],
|
87 |
outputs=["text",
|
88 |
gr.Plot(label="Anomaly Heatmap"),
|
89 |
gr.Plot(label="Comparing to Imagenette"),
|