Spaces:
Runtime error
Runtime error
File size: 3,817 Bytes
b059e33 7cfb8be 080f6e6 b059e33 509b38e 8849d6d b059e33 76cb22f b059e33 fe1c53b b059e33 76cb22f b059e33 4c8eb7e 8849d6d fe1c53b 4a989c6 fe1c53b 9849d45 fe1c53b 4c8eb7e b059e33 4c8eb7e b059e33 421625a b059e33 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
import functools

import tensorflow as tf
import numpy as np
from tensorflow import keras
from keras.layers import Input, Lambda, Dense, Flatten, Rescaling
from keras.models import Model
import PIL
from PIL import Image
import gradio as gr
import matplotlib.cm as cm
# Frozen Xception feature extractor (no classification top) shared by every
# model that create_model_mod() builds.
# No `weights` argument is passed, so Keras' documented default (ImageNet
# weights) is used. An offline alternative is to pass a local
# "xception_weights_tf_dim_ordering_tf_kernels_notop.h5" file as `weights`.
base_model = keras.applications.Xception(
    include_top=False,
    input_shape=(160, 160, 3),
)
base_model.trainable = False
def img_pros(img):
    """Turn one input image into a model-ready batch of size one.

    Args:
        img: a PIL image (as delivered by the Gradio ``type="pil"`` input) or
            an array-like image of shape (H, W, C).

    Returns:
        A float tensor of shape (1, 160, 160, C), resized with
        ``tf.image.resize`` (aspect ratio is not preserved). Pixel values keep
        their original scale; rescaling happens inside the model.
    """
    # The network expects exactly 3 channels. PIL images can be RGBA,
    # grayscale or palette-based, which would otherwise yield 4 or 1 channel(s).
    if isinstance(img, Image.Image) and img.mode != "RGB":
        img = img.convert("RGB")
    pixels = keras.utils.img_to_array(img)
    resized = tf.image.resize(pixels, [160, 160])
    return tf.expand_dims(resized, axis=0)
def create_model_mod():
    """Build the binary classifier on top of the frozen ``base_model``.

    Graph: (160, 160, 3) input -> Rescaling(1/255) -> ``base_model`` (called
    with ``training=False``) -> GlobalAveragePooling2D -> Dense(1, linear).

    Returns:
        ``(model, inputs, conv_features, outputs)``: the compiled model, its
        input tensor, the backbone's last feature-map output, and the model's
        logit output. The extra tensors let callers build a Grad-CAM model
        that exposes both the feature maps and the prediction.
    """
    inputs = keras.Input(shape=(160, 160, 3))
    # NOTE(review): this scales pixels to [0, 1], whereas Xception's own
    # preprocess_input maps to [-1, 1]. Kept as-is because the saved head
    # weights were presumably trained with this scaling; confirm before changing.
    rescaled = Rescaling(scale=1.0 / 255)(inputs)
    conv_features = base_model(rescaled, training=False)
    pooled = keras.layers.GlobalAveragePooling2D()(conv_features)
    # Linear head: the output is a logit, hence from_logits=True in the loss.
    outputs = keras.layers.Dense(1, activation='linear')(pooled)
    model = keras.Model(inputs, outputs)
    model.compile(
        loss=keras.losses.BinaryCrossentropy(from_logits=True),
        optimizer=keras.optimizers.Adam(0.001),
        # Outputs are logits, so the decision boundary is 0 (sigmoid(0) == 0.5).
        # The plain "accuracy" metric would threshold the logits at 0.5 instead.
        metrics=[keras.metrics.BinaryAccuracy(threshold=0.0)],
    )
    return model, inputs, conv_features, outputs
def create_heatmap(model, imgs):
    """Compute a Grad-CAM style class-activation map.

    Args:
        model: a Keras model whose call returns ``[feature_maps, predictions]``.
        imgs: an input batch accepted by ``model`` (the channel weights are
            averaged over the batch axis too, so a batch of one is assumed).

    Returns:
        ``(heatmap, preds)``: the ReLU'd activation map with all size-1 axes
        squeezed out (a ``tf.Tensor``), and the predictions as a NumPy array.
    """
    with tf.GradientTape() as tape:
        feature_maps, predictions = model(imgs)
    # Sensitivity of the prediction to every feature-map activation.
    channel_grads = tape.gradient(predictions, feature_maps)
    # One importance weight per channel: mean gradient over batch, height, width.
    channel_weights = tf.reduce_mean(channel_grads, axis=(0, 1, 2))
    # Weighted sum of the feature maps along the channel axis.
    weighted_sum = tf.matmul(feature_maps, tf.expand_dims(channel_weights, axis=-1))
    # Keep only regions that push the prediction up.
    activation_map = tf.nn.relu(tf.squeeze(weighted_sum))
    return activation_map, predictions.numpy()
def superimpose_single(heatmap, img, alpha=0.4):
    """Overlay a jet-coloured heatmap on an image.

    Args:
        heatmap: 2-D activation map (tensor or array), e.g. from create_heatmap.
            It is rescaled to [0, 1] here, so any non-negative scale is accepted.
        img: image of shape (H, W, 3); presumably pixel values in [0, 255] as
            produced by img_pros. TODO confirm for other callers.
        alpha: weight of the coloured heatmap in the blend.

    Returns:
        A float NumPy array of shape (H, W, 3) with values clipped to [0, 255].
    """
    img = np.asarray(img, dtype="float32")
    heatmap = np.asarray(heatmap, dtype="float32")
    # Normalise to [0, 1] so 255 * heatmap fits in uint8 instead of wrapping
    # around. An all-zero map (no positive evidence) is left as zeros.
    peak = float(heatmap.max()) if heatmap.size else 0.0
    if peak > 0:
        heatmap = heatmap / peak
    heatmap = np.uint8(255 * np.clip(heatmap, 0.0, 1.0))

    # cm.get_cmap was removed in Matplotlib 3.9; use the colormap registry
    # when available and fall back for older Matplotlib versions.
    try:
        from matplotlib import colormaps
        jet = colormaps["jet"]
    except ImportError:
        jet = cm.get_cmap("jet")
    # RGB rows of the 256-entry jet lookup table (alpha column dropped).
    jet_colors = jet(np.arange(256))[:, :3]
    jet_heatmap = jet_colors[heatmap]

    # Upsample the coloured map to the image size (PIL takes (width, height)).
    jet_heatmap = keras.utils.array_to_img(jet_heatmap)
    jet_heatmap = jet_heatmap.resize((img.shape[1], img.shape[0]))
    jet_heatmap = keras.utils.img_to_array(jet_heatmap)

    # Additive blend; clip so a later uint8 cast cannot wrap bright pixels.
    superimposed_img = jet_heatmap * alpha + img
    return np.clip(superimposed_img, 0, 255)
@functools.lru_cache(maxsize=4)
def _load_grad_model(weights):
    """Build the classifier, load ``weights`` and wrap it as a Grad-CAM model.

    The returned model maps an image batch to ``[conv_features, logits]``.
    It is cached per weights path, so the network is built and the file is
    read once instead of on every request. ``weights`` must therefore be
    hashable (e.g. a str path), and later changes to the file on disk are not
    picked up.
    """
    model_mod, model_inputs, conv_features, model_outputs = create_model_mod()
    model_mod.load_weights(weights)
    return Model(model_inputs, [conv_features, model_outputs])


def gen_grad_img_single(weights, img, alpha=0.4):
    """Run Grad-CAM on one preprocessed image and overlay the heatmap.

    Args:
        weights: path of the classifier weights file passed to ``load_weights``.
        img: a batch holding a single image, as produced by img_pros. Only
            ``img[0]`` is used for the overlay.
        alpha: weight of the heatmap in the overlay.

    Returns:
        ``(overlay, y_pred)``: the overlay as a uint8 array and the raw model
        predictions returned by create_heatmap.
    """
    grad_model = _load_grad_model(weights)
    heatmap, y_pred = create_heatmap(grad_model, img)
    # Pass alpha through so the caller's overlay weight is honoured.
    overlay = superimpose_single(heatmap, img[0], alpha=alpha)
    # Clip before casting: values above 255 would otherwise wrap around in uint8.
    return np.clip(np.asarray(overlay), 0, 255).astype('uint8'), y_pred
# Location of the trained classifier weights; get_grad passes this path to
# gen_grad_img_single, which loads it into the model.
weights = "weights.h5"
def get_grad(img):
    """Gradio callback: classify an image and return its Grad-CAM overlay.

    Args:
        img: the uploaded image (a PIL image, given the ``type="pil"`` input).

    Returns:
        ``(grad_img, text)``: the uint8 overlay image, and a summary with the
        raw model score, the derived probability and the predicted class.
    """
    batch = img_pros(img)
    grad_img, y_pred = gen_grad_img_single(weights, batch)
    # The classifier head is linear and trained with from_logits=True, so the
    # raw score is a logit, not a probability. Comparing the logit to 0 is
    # equivalent to sigmoid(logit) > 0.5.
    logit = float(np.ravel(y_pred)[0])
    prob = 0.5 * (1.0 + np.tanh(0.5 * logit))  # overflow-free sigmoid
    pred_class = "cat" if logit > 0.0 else "dog"
    text = (
        "Raw Score: " + str(y_pred[0])
        + "\nProbability (cat): " + f"{prob:.4f}"
        + "\nClassification: " + pred_class
    )
    return grad_img, text
# Gradio UI: upload an image and get back the Grad-CAM overlay plus the
# prediction text produced by get_grad.
# gr.Image receives no `shape=` argument. Gradio 4 removed that parameter
# (passing it raises TypeError at startup), and img_pros resizes every input
# to 160x160 anyway.
demo = gr.Interface(
    fn=get_grad,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(type="numpy", width=320, height=320),
        gr.Textbox(label='Prediction', info='(threshold: 0.5)'),
    ],
)

if __name__ == "__main__":
    demo.launch()
|