File size: 4,762 Bytes
3f38b4a dcc7d28 8e487ef 3f38b4a 8e487ef 3f38b4a 8e487ef 3f38b4a 8e487ef 3f38b4a 8e487ef 3f38b4a 8e487ef 9c8a691 8e487ef 52e7bd3 8e487ef 52e7bd3 5dda0e8 8e487ef 9c8a691 8e487ef 9c8a691 8e487ef 5f71bd1 8e487ef 9c8a691 8e487ef 9c8a691 8e487ef 3f38b4a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
import torch, torchvision
from torchvision import transforms
import numpy as np
import gradio as gr
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from resnet import custom_ResNet
import gradio as gr
import os
# Build the network and restore the trained weights on CPU.
model = custom_ResNet()
# NOTE(review): strict=False makes load_state_dict skip missing/unexpected keys
# without raising, so a mismatched checkpoint would load "successfully" with
# partially random weights. Confirm this is intentional.
model.load_state_dict(
    torch.load("custom_resnet_model.pth", map_location=torch.device("cpu")),
    strict=False,
)
model.setup(stage="test")
# Switch to evaluation mode so that mode-dependent layers (BatchNorm, Dropout,
# if the model has any) use their inference behaviour from the very first
# prediction instead of their training behaviour.
model.eval()

# Inverse of Normalize(mean=0.50, std=0.23) on every channel:
# Normalize(-m/s, 1/s) maps x -> x * s + m.
inv_normalize = transforms.Normalize(
    mean=[-0.50 / 0.23] * 3,
    std=[1 / 0.23] * 3,
)

# Class labels, indexed by the position of the corresponding model output.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)
def inference(input_img, transparency=0.5, target_layer_number=-1, top_classes=3):
    """Classify an image and overlay a Grad-CAM heatmap on it.

    Args:
        input_img: Image handed over by the Gradio image input. It is fed to
            ``transforms.ToTensor()`` and divided by 255 for the overlay, so
            it is presumably an HxWx3 uint8 array -- TODO confirm.
        transparency: Forwarded to ``show_cam_on_image`` as ``image_weight``
            (the weight given to the original image in the blend).
        target_layer_number: Accepted for interface compatibility but
            currently ignored: Grad-CAM is always computed on
            ``model.convblock2_l1``.
        top_classes: How many of the highest-confidence classes to return.
            Coerced to ``int`` because UI sliders may deliver a float.

    Returns:
        A ``(confidences, visualization)`` tuple: a dict mapping class name
        to softmax confidence for the ``top_classes`` best classes, in
        descending order, and the image returned by ``show_cam_on_image``.
    """
    original_image = input_img
    input_tensor = transforms.ToTensor()(input_img).unsqueeze(0)

    # Gradients are not needed for the plain classification pass; Grad-CAM
    # below runs its own forward/backward pass on the same tensor.
    with torch.no_grad():
        logits = model(input_tensor)
        probabilities = torch.softmax(logits, dim=1)[0]
    confidences = {
        name: float(probabilities[index]) for index, name in enumerate(classes)
    }

    # NOTE(review): target_layer_number is not used to pick the layer; the
    # layer is hard-coded here.
    target_layers = [model.convblock2_l1]
    cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)
    # targets=None is passed through unchanged; see the pytorch-grad-cam docs
    # for which class is explained in that case.
    grayscale_cam = cam(input_tensor=input_tensor, targets=None)[0, :]

    visualization = show_cam_on_image(
        original_image / 255,
        grayscale_cam,
        use_rgb=True,
        image_weight=transparency,
    )

    # Keep only the best `top_classes` entries, highest confidence first.
    ranked = sorted(confidences.items(), key=lambda item: item[1], reverse=True)
    top_classes_confidences = dict(ranked[:int(top_classes)])
    return top_classes_confidences, visualization
# Adapter between the Gradio widgets and model.show_misclassified_images()
def show_misclassified_images_wrapper(num_images=10, use_gradcam=False, gradcam_layer=-1, transparency=0.5):
    """Normalise widget values and delegate to the model's plotting helper.

    Args:
        num_images: Number of images to display; converted with ``int``.
        use_gradcam: Treated as enabled only when it equals the string
            ``"Yes"``; any other value (including the default ``False``)
            disables Grad-CAM.
        gradcam_layer: Passed through unchanged.
        transparency: Converted with ``float`` and passed through.

    Returns:
        Whatever ``model.show_misclassified_images`` returns.
    """
    opacity = float(transparency)
    image_count = int(num_images)
    gradcam_enabled = use_gradcam == "Yes"
    return model.show_misclassified_images(
        image_count, gradcam_enabled, gradcam_layer, opacity
    )
description1 = "Supported Only - plane, car, bird, cat, deer, dog, frog, horse, ship, truck."

# Folder that holds one sample picture per class
images_folder = "examples"

# One example row per class: [image path, opacity, layer index]
examples = [
    [os.path.join(images_folder, label + ".jpg"), 0.5, -1]
    for label in (
        "plane", "car", "bird", "cat", "deer",
        "dog", "frog", "horse", "ship", "truck",
    )
]
# "Input an image" tab: widgets are declared first, then wired to inference()
inference_inputs = [
    gr.Image(shape=(32, 32), label="Input Image"),
    gr.Slider(0, 1, value=0.5, label="Opacity of GradCAM"),
    gr.Slider(-2, -1, value=-2, step=1, label="Which Layer?"),
    gr.Slider(1, 10, value=3, step=1, label="How many classes"),
]
inference_outputs = [
    gr.Label(),
    gr.Image(shape=(32, 32), label="Predicted Output").style(width=300, height=300),
]
input_interface = gr.Interface(
    fn=inference,
    inputs=inference_inputs,
    outputs=inference_outputs,
    description=description1,
    examples=examples,
)
# Text shown above the widgets of the second tab (spelling corrected).
description2 = "Misclassified Images"

# "Misclassified Images" tab: the four inputs map, in order, onto the
# parameters of show_misclassified_images_wrapper
# (num_images, use_gradcam, gradcam_layer, transparency).
misclassified_interface = gr.Interface(
    show_misclassified_images_wrapper,
    inputs=[
        gr.Number(value=10, label="Number of Images for display"),
        gr.Radio(["Yes", "No"], value="No", label="Show GradCAM outputs"),
        gr.Slider(-2, -1, value=-1, step=1, label="Which layer for GradCAM?"),
        gr.Slider(0, 1, value=0.5, label="Opacity of GradCAM"),
    ],
    outputs=gr.Plot(),
    description=description2,
)
# Combine both interfaces into one tabbed app and start serving it.
demo = gr.TabbedInterface(
    [input_interface, misclassified_interface],
    tab_names=["Input an image", "Misclassified Images"],
    title="Gradcam using Cifar10 with CustomResnet Model",
)
demo.launch()