Spaces:
Runtime error
Runtime error
File size: 3,988 Bytes
f740d84 d57d2f2 dfbc387 f740d84 9c45667 f740d84 dfbc387 11bce97 f740d84 11bce97 f740d84 dfbc387 f740d84 dfbc387 f740d84 dfbc387 f740d84 dfbc387 f740d84 dfbc387 f740d84 dfbc387 f740d84 dfbc387 f740d84 b037fad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
import tensorflow as tf
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2 as keras_model
from tensorflow.keras.applications.mobilenet_v2 import (
preprocess_input,
decode_predictions,
)
import matplotlib.pyplot as plt
from alibi.explainers import IntegratedGradients
from alibi.datasets import load_cats
from alibi.utils.visualization import visualize_image_attr
import numpy as np
from PIL import Image, ImageFilter
import io
import time
import os
import copy
import pickle
import datetime
import urllib.request
import gradio as gr
# Example images offered in the Gradio UI (see `examples` below). Each file is
# fetched only when it is missing, so a restart does not depend on the network.
# The loop leaves `url` / `path_input` bound to the last (dog) entry.
for url, path_input in (
    (
        "https://upload.wikimedia.org/wikipedia/commons/3/38/Adorable-animal-cat-20787.jpg",
        "./cat.jpg",
    ),
    (
        "https://upload.wikimedia.org/wikipedia/commons/4/43/Cute_dog.jpg",
        "./dog.jpg",
    ),
):
    if os.path.exists(path_input):
        # Already present from a previous run; reuse it.
        continue
    # Wikimedia's User-Agent policy asks clients to identify themselves; the
    # generic "Python-urllib/x.y" default may be refused.
    # NOTE(review): adjust the agent string / add contact info if required.
    request = urllib.request.Request(
        url,
        headers={"User-Agent": "xai-integrated-gradients-demo/1.0 (Gradio Space)"},
    )
    with urllib.request.urlopen(request) as response:
        data = response.read()
    # Write only after the whole body arrived, so a failed download does not
    # leave a truncated file that later runs would silently reuse.
    with open(path_input, "wb") as fh:
        fh.write(data)
# Classifier under explanation: MobileNetV2 with ImageNet weights.
model = keras_model(weights="imagenet")

# Settings forwarded to alibi's IntegratedGradients explainer:
# number of integration steps, integration method, and internal batch size.
n_steps, method, internal_batch_size = 50, "gausslegendre", 50
ig = IntegratedGradients(
    model,
    n_steps=n_steps,
    method=method,
    internal_batch_size=internal_batch_size,
)
def _make_baseline(img, baseline, shape, dtype):
    """Build an Integrated Gradients baseline and a viewable image of it.

    The baseline is constructed in pixel space (0-255) and then passed through
    ``preprocess_input``, so it lives in the same input space as the explained
    instance. Building e.g. ``np.zeros`` directly in model space would mean
    mid-gray rather than black, because MobileNetV2 preprocessing maps
    pixels 0..255 to -1..1.

    Args:
        img: Original PIL image (only used for the "blur" baseline).
        baseline: "white", "black" or "blur"; any other value selects
            uniform random noise over the full pixel range.
        shape: Shape of the batched model input (leading batch axis of 1).
        dtype: dtype of the model input.

    Returns:
        Tuple ``(baselines, img_flt)``: the preprocessed baseline array of
        ``shape``, and a PIL image of the baseline in pixel space.
    """
    if baseline == "white":
        base_pixels = np.full(shape, 255.0, dtype=dtype)
    elif baseline == "black":
        base_pixels = np.zeros(shape, dtype=dtype)
    elif baseline == "blur":
        blurred = img.filter(ImageFilter.GaussianBlur(5))
        base_pixels = np.expand_dims(image.img_to_array(blurred), axis=0)
    else:
        base_pixels = np.random.uniform(0.0, 255.0, size=shape).astype(dtype)
    # Render before preprocessing: preprocess_input may rescale float input in place.
    img_flt = Image.fromarray(np.uint8(np.squeeze(base_pixels, axis=0)))
    return preprocess_input(base_pixels), img_flt


def do_process(img, baseline):
    """Classify ``img`` with the global model and explain its top prediction.

    Args:
        img: PIL image from the Gradio input (configured as 224x224 RGB).
        baseline: "white", "black", "blur", or anything else for a random
            baseline (see ``_make_baseline``).

    Returns:
        Tuple ``(img_res, img_flt, dctPreds)``:
            img_res: PIL image of the blended attribution heat map.
            img_flt: PIL image of the baseline that was used.
            dctPreds: dict mapping the top-3 decoded class names to their
                scores rounded to 2 decimals.

    Raises:
        ValueError: if no image was supplied.
    """
    if img is None:
        raise ValueError("No input image was provided.")

    # Batched model input: (1, H, W, C) pixel array, then model preprocessing.
    instance = preprocess_input(np.expand_dims(image.img_to_array(img), axis=0))

    preds = model.predict(instance)
    dctPreds = {
        name: round(float(score), 2)
        for _, name, score in decode_predictions(preds, top=3)[0]
    }
    # Attributions are computed for the top-1 class of each batch row.
    predictions = preds.argmax(axis=1)

    baselines, img_flt = _make_baseline(img, baseline, instance.shape, instance.dtype)
    explanation = ig.explain(instance, baselines=baselines, target=predictions)
    attrs = explanation.attributions[0]

    fig, _ = visualize_image_attr(
        attr=attrs.squeeze(),
        original_image=img,
        method="blended_heat_map",
        sign="all",
        show_colorbar=True,
        title=baseline,
        plt_fig_axis=None,
        use_pyplot=False,
    )
    fig.tight_layout()
    # Render the figure into an in-memory image file and return it as PIL.
    buf = io.BytesIO()
    fig.savefig(buf)
    buf.seek(0)
    img_res = Image.open(buf)
    return img_res, img_flt, dctPreds
# Gradio UI wiring for do_process.
# NOTE: gr.inputs / gr.outputs and Interface(interpretation=...) are the legacy
# Gradio API, removed in Gradio 4.0 — pin gradio<4 or migrate to
# gr.Image / gr.Dropdown / gr.Label.
input_im = gr.inputs.Image(
    # Resized to 224x224 RGB before being handed to do_process as a PIL image.
    shape=(224, 224), image_mode="RGB", invert_colors=False, source="upload", type="pil"
)
input_drop = gr.inputs.Dropdown(
    label="Baseline (default: random)",
    choices=["random", "black", "white", "blur"],
    default="random",
    type="value",
)
# Output order matches do_process's return tuple: (heat map, baseline, predictions).
output_img = gr.outputs.Image(label="Output of Integrated Gradients", type="pil")
output_base = gr.outputs.Image(label="Baseline image", type="pil")
output_label = gr.outputs.Label(label="Classification results", num_top_classes=3)
title = "XAI - Integrated gradients"
# The explained model is MobileNetV2 (see `model = keras_model(...)`), not ResNet.
description = "Playground: Integrated gradients for a MobileNetV2 model trained on Imagenet dataset. Tools: Alibi, TF, Gradio."
# Paths of the example images downloaded at startup.
examples = [["./cat.jpg", "blur"], ["./dog.jpg", "random"]]
article = "<p style='text-align: center'><a href='https://github.com/mawady' target='_blank'>By Dr. Mohamed Elawady</a></p>"
iface = gr.Interface(
    fn=do_process,
    inputs=[input_im, input_drop],
    outputs=[output_img, output_base, output_label],
    live=False,
    interpretation=None,
    title=title,
    description=description,
    article=article,
    examples=examples,
)
iface.launch(debug=True)
|