File size: 3,478 Bytes
b059e33
 
 
 
 
 
7cfb8be
 
b059e33
509b38e
 
 
 
 
 
 
b059e33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76cb22f
b059e33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76cb22f
b059e33
4c8eb7e
 
 
 
b059e33
4c8eb7e
b059e33
 
ee191cc
b059e33
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import tensorflow as tf
import numpy as np
from tensorflow import keras
from keras.layers import Input, Lambda, Dense, Flatten, Rescaling
from keras.models import Model
import PIL
from PIL import Image
import gradio as gr

# Shared Xception feature extractor with its classification head removed
# (include_top=False). The `weights` argument is left at the Keras default.
# The backbone is frozen so that only layers stacked on top of it train.
base_model = keras.applications.Xception(
    include_top=False,
    input_shape=(160, 160, 3),
)
base_model.trainable = False

def create_model_mod():
    """Build and compile the binary classifier stacked on ``base_model``.

    The graph is: 160x160x3 input -> rescale pixels by 1/255 -> frozen
    ``base_model`` (called in inference mode) -> global average pooling ->
    a single linear unit. It is compiled with binary cross-entropy on logits,
    Adam (learning rate 0.001) and an accuracy metric.

    Returns:
        A 4-tuple ``(model, inputs, features, logits)``: the compiled model,
        its input tensor, the output of the Xception backbone (its last conv
        feature maps) and the model's output tensor.
    """
    inputs = keras.Input(shape=(160, 160, 3))
    scaled = Rescaling(scale=1.0 / 255)(inputs)  # normalise pixel values
    features = base_model(scaled, training=False)
    pooled = keras.layers.GlobalAveragePooling2D()(features)
    logits = keras.layers.Dense(1, activation="linear")(pooled)

    model = keras.Model(inputs, logits)
    model.compile(
        loss=keras.losses.BinaryCrossentropy(from_logits=True),
        optimizer=keras.optimizers.Adam(0.001),
        metrics=["accuracy"],
    )
    return model, inputs, features, logits

def create_heatmap(model, imgs):
    """Compute Grad-CAM heatmaps for a batch of images.

    Args:
        model: a model whose call returns ``(feature_maps, preds)``, e.g. the
            grad model built from ``create_model_mod``'s outputs. The feature
            maps are assumed to be channels-last ``(batch, h, w, channels)``
            (TODO confirm for other models).
        imgs: a batch of images accepted by ``model``.

    Returns:
        ``(heatmap, preds)``: the ReLU'd (not normalised) class activation
        maps with all size-1 dimensions squeezed out, so a single image gives
        an ``(h, w)`` map; and the model's predictions as a NumPy array.
    """
    # Record the forward pass so gradients w.r.t. the feature maps exist.
    with tf.GradientTape() as tape:
        maps, preds = model(imgs)

    # Gradients of the predictions w.r.t. the feature maps.
    grads = tape.gradient(preds, maps)

    # Average the gradients over the spatial axes only, so every image gets its
    # own channel weights. Averaging over the batch axis as well would mix
    # gradients from different images whenever len(imgs) > 1; for a single
    # image both forms are identical.
    channel_weights = tf.reduce_mean(grads, axis=(1, 2))  # (batch, channels)

    # Weighted sum of the feature maps over the channel axis -> (batch, h, w).
    heatmap = tf.einsum("bhwc,bc->bhw", maps, channel_weights)

    # Drop size-1 dimensions (e.g. the batch axis for a single image).
    heatmap = tf.squeeze(heatmap)

    # Keep only features with a positive influence on the prediction.
    heatmap = tf.keras.activations.relu(heatmap)

    return heatmap, preds.numpy()

# Control points of matplotlib's "jet" colormap, per RGB channel:
# (x positions, channel values), linearly interpolated in between.
# Rebuilding the colormap with NumPy avoids needing matplotlib, which this
# module does not import.
_JET_SEGMENTS = (
    ((0.0, 0.35, 0.66, 0.89, 1.0), (0.0, 0.0, 1.0, 1.0, 0.5)),  # red
    ((0.0, 0.125, 0.375, 0.64, 0.91, 1.0), (0.0, 0.0, 1.0, 1.0, 0.0, 0.0)),  # green
    ((0.0, 0.11, 0.34, 0.65, 1.0), (0.5, 1.0, 1.0, 0.0, 0.0)),  # blue
)


def _jet_colors(n=256):
    """Return an ``(n, 3)`` array of RGB values in [0, 1] sampling the jet colormap."""
    positions = np.linspace(0.0, 1.0, n)
    return np.stack(
        [np.interp(positions, xs, ys) for xs, ys in _JET_SEGMENTS], axis=1
    )


def superimpose_single(heatmap, img, alpha = 0.4):
    """Overlay a jet-coloured heatmap on a single image.

    Args:
        heatmap: 2-D activation map (tensor or array), e.g. from
            ``create_heatmap``. Negative values are clipped to 0 and the map is
            scaled so its maximum becomes 1 before colouring.
        img: the image to draw on, ``(H, W, 3)``; pixel values presumably in
            [0, 255] (TODO confirm with callers).
        alpha: weight of the coloured heatmap added to ``img``.

    Returns:
        Float array ``coloured_heatmap * alpha + img`` of the same spatial size
        as ``img``. Values are NOT clipped, so clip before casting to uint8.
    """
    # Normalise to [0, 1]; without this, 255 * heatmap overflows uint8.
    heatmap = np.maximum(np.asarray(heatmap, dtype=np.float32), 0.0)
    peak = heatmap.max() if heatmap.size else 0.0
    if peak > 0:
        heatmap = heatmap / peak
    heatmap = np.uint8(255 * heatmap)

    # Map each 0-255 intensity to its jet RGB colour.
    jet_heatmap = _jet_colors()[heatmap]

    # Round-trip through a PIL image to resize the coloured map to the
    # image's own spatial size (PIL's resize takes (width, height)).
    jet_heatmap = keras.utils.array_to_img(jet_heatmap)
    height, width = np.shape(img)[:2]
    jet_heatmap = jet_heatmap.resize((width, height))
    jet_heatmap = keras.utils.img_to_array(jet_heatmap)

    # Superimpose the heatmap on the original image.
    return jet_heatmap * alpha + img

# Single-slot cache: weights path -> grad model. Every model from
# create_model_mod shares the global base_model, so loading a different
# weights file also overwrites the backbone of any previously built model.
# Only the most recent model is therefore kept.
_GRAD_MODEL_CACHE = {}


def _get_grad_model(weights):
    """Return a model mapping images to (Xception feature maps, logits) for *weights*."""
    cached = _GRAD_MODEL_CACHE.get(weights)
    if cached is not None:
        return cached
    _GRAD_MODEL_CACHE.clear()
    model_mod, inputs, features, logits = create_model_mod()
    model_mod.load_weights(weights)
    grad_model = Model(inputs, [features, logits])
    _GRAD_MODEL_CACHE[weights] = grad_model
    return grad_model


def gen_grad_img_single(weights, img, alpha = 0.4):
    """Classify an image and return its Grad-CAM overlay.

    Args:
        weights: path of the weights file loaded into the model built by
            ``create_model_mod``. The built model is reused while the same
            path is requested.
        img: a single image ``(H, W, 3)`` or a batch ``(N, H, W, 3)`` (array,
            tensor or RGB PIL image). Images are resized to the model's input
            size if needed. Only the first image is overlaid.
        alpha: strength of the heatmap in the overlay.

    Returns:
        ``(overlay, y_pred)``: the overlay as a uint8 array clipped to
        [0, 255], and the predictions thresholded to 1.0 / 0.0.
    """
    grad_model = _get_grad_model(weights)

    imgs = np.asarray(img, dtype=np.float32)
    if imgs.ndim == 3:  # single image -> batch of one
        imgs = imgs[np.newaxis]
    target_hw = tuple(grad_model.input_shape[1:3])
    if tuple(imgs.shape[1:3]) != target_hw:
        imgs = tf.image.resize(imgs, target_hw).numpy()

    heatmaps, y_pred = create_heatmap(grad_model, imgs)
    heatmaps = np.asarray(heatmaps)
    if heatmaps.ndim == 3:  # batch of maps -> keep the first image's map
        heatmaps = heatmaps[0]

    # The model's head is linear and trained with from_logits=True, so its
    # outputs are logits: logit > 0 is equivalent to sigmoid(logit) > 0.5.
    y_pred = (y_pred > 0).astype(y_pred.dtype)

    overlay = superimpose_single(heatmaps, imgs[0], alpha)
    # Clip before the cast; values above 255 would otherwise wrap around.
    return np.clip(overlay, 0, 255).astype('uint8'), y_pred


# Path of the trained weights file; get_grad passes it to gen_grad_img_single.
weights: str = "weights.h5"

def get_grad(img):
    """Gradio callback: Grad-CAM overlay and prediction for one uploaded image.

    Args:
        img: the upload as a PIL image (the interface uses ``type="pil"``), or
            ``None`` when the user submitted without an image.

    Returns:
        The ``(overlay, y_pred)`` pair from ``gen_grad_img_single``, or
        ``(None, None)`` when there is no image.
    """
    if img is None:
        return None, None
    if isinstance(img, Image.Image):
        # Force 3 channels and the backbone's input size, and add a batch
        # axis, so the array matches what the model was built for.
        height, width = base_model.input_shape[1:3]
        img = img.convert("RGB").resize((width, height))
        img = np.asarray(img, dtype=np.float32)[np.newaxis]
    return gen_grad_img_single(weights, img)

# The classifier built by create_model_mod takes 160x160x3 input. Asking Gradio
# for images of that size avoids a 224x224 upload that the model and overlay
# code (which work at 160x160) do not match.
demo = gr.Interface(
    fn=get_grad,
    inputs=gr.Image(type="pil", shape=(160, 160)),
    outputs=[gr.Image(type="numpy"), gr.Textbox('y_pred', label='Prediction')],
)
demo.launch()