|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torchsummary import summary |
|
from io import BytesIO |
|
import numpy as np |
|
import os |
|
from pytorch_lightning import LightningModule, Trainer |
|
from torch import nn |
|
from torch.nn import functional as F |
|
from torch.utils.data import DataLoader, random_split |
|
|
|
from torchvision import transforms |
|
from torchvision.datasets import CIFAR10 |
|
from torch_lr_finder import LRFinder |
|
import math |
|
from pytorch_grad_cam import GradCAM |
|
from pytorch_grad_cam.utils.image import show_cam_on_image |
|
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget |
|
from PIL import Image |
|
import torch |
|
from torch.utils.data import DataLoader, random_split |
|
import torchvision.transforms as transforms |
|
import torchvision.datasets as datasets |
|
import pytorch_lightning as pl |
|
import matplotlib.pyplot as plt |
|
import matplotlib.gridspec as gridspec |
|
|
|
|
|
# Root directory for the CIFAR-10 files; override with the PATH_DATASETS
# environment variable, otherwise the current working directory is used.
PATH_DATASETS = os.getenv("PATH_DATASETS", ".")

# Default mini-batch size used by every DataLoader built in this module.
BATCH_SIZE = 256
|
|
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torchsummary import summary |
|
from io import BytesIO |
|
import numpy as np |
|
|
|
|
|
class custom_ResNet(pl.LightningModule):
    """ResNet-style CIFAR-10 classifier bundled with its data pipeline.

    The module owns the network, the train/val/test dataloaders and a few
    visualisation helpers (misclassified-image grids and GradCAM overlays).

    The network expects 3-channel 32x32 inputs: three 2x2 max-pools bring the
    feature map down to 4x4, ``convblock4_mp`` pools that to 1x1, and the
    resulting 512 features feed the linear classifier.

    Args:
        data_dir: Directory where CIFAR-10 is downloaded to / read from.
        batch_size: Mini-batch size used by all dataloaders.
        learning_rate: Learning rate handed to the Adam optimizer.
    """

    # Per-channel statistics used by ``transforms.Normalize``. The same values
    # are reused to undo the normalisation before drawing GradCAM overlays, so
    # the round trip is exact.
    CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
    CIFAR_STD = (0.247, 0.243, 0.261)

    def __init__(self, data_dir=PATH_DATASETS, batch_size=BATCH_SIZE, learning_rate=0.001):
        super().__init__()

        self.data_dir = data_dir
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
        self.num_classes = 10

        self.train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(self.CIFAR_MEAN, self.CIFAR_STD),
        ])
        self.test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(self.CIFAR_MEAN, self.CIFAR_STD),
        ])

        # NOTE: attribute names and the order of layers inside every
        # nn.Sequential define the state_dict keys; keep them stable so that
        # existing checkpoints continue to load.
        self.prepblock = nn.Sequential(*self._conv_layers(3, 64))

        self.convblock1_l1 = nn.Sequential(*self._conv_layers(64, 128, pool=True))
        self.convblock1_r1 = nn.Sequential(
            *self._conv_layers(128, 128),
            *self._conv_layers(128, 128),
        )

        self.convblock2_l1 = nn.Sequential(*self._conv_layers(128, 256, pool=True))

        self.convblock3_l1 = nn.Sequential(*self._conv_layers(256, 512, pool=True))
        self.convblock3_r2 = nn.Sequential(
            *self._conv_layers(512, 512),
            *self._conv_layers(512, 512),
        )

        self.convblock4_mp = nn.Sequential(nn.MaxPool2d(4))

        self.output_block = nn.Sequential(
            nn.Linear(in_features=512, out_features=self.num_classes, bias=False)
        )

    @staticmethod
    def _conv_layers(in_channels, out_channels, pool=False):
        """Return the layers of one 3x3 conv unit: Conv -> [MaxPool] -> ReLU -> BatchNorm."""
        layers = [
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(3, 3),
                      padding=1, dilation=1, stride=1, bias=False),
        ]
        if pool:
            layers.append(nn.MaxPool2d(2, 2))
        layers.extend([nn.ReLU(), nn.BatchNorm2d(out_channels)])
        return layers

    def forward(self, x):
        """Run the network and return per-class log-probabilities.

        Args:
            x: Image batch; the first dimension is the batch dimension.

        Returns:
            Tensor of ``log_softmax`` outputs with one row per input image.
        """
        x = self.prepblock(x)

        # Stage 1: down-sampling conv followed by a residual pair.
        x = self.convblock1_l1(x)
        x = x + self.convblock1_r1(x)

        # Stage 2: down-sampling conv only (no residual branch).
        x = self.convblock2_l1(x)

        # Stage 3: down-sampling conv followed by a residual pair.
        x = self.convblock3_l1(x)
        x = self.convblock3_r2(x) + x

        x = self.convblock4_mp(x)
        x = torch.flatten(x, 1)
        return F.log_softmax(self.output_block(x), dim=1)

    def _shared_step(self, batch):
        """Compute loss, accuracy and predictions for one ``(images, labels)`` batch."""
        x, y = batch
        log_probs = self(x)
        # forward() already applies log_softmax, so NLL is the matching loss;
        # cross_entropy would apply log_softmax a second time.
        loss = F.nll_loss(log_probs, y)
        pred = log_probs.argmax(dim=1, keepdim=True)
        acc = pred.eq(y.view_as(pred)).float().mean()
        return loss, acc, pred

    def training_step(self, batch, batch_idx):
        """Log and return the training loss for one batch."""
        loss, acc, _ = self._shared_step(batch)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_acc', acc, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log and return the validation loss for one batch."""
        loss, acc, _ = self._shared_step(batch)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Log test loss/accuracy and return the predicted class indices."""
        loss, acc, pred = self._shared_step(batch)
        self.log('test_loss', loss, prog_bar=True)
        self.log('test_acc', acc, prog_bar=True)
        return pred

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    def prepare_data(self):
        """Download the CIFAR-10 train and test archives if they are missing."""
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Create the datasets needed for ``stage``.

        Args:
            stage: ``"fit"``, ``"validate"``, ``"test"`` or ``None`` (build everything).
        """
        if stage in ("fit", "validate", None):
            augmented = CIFAR10(self.data_dir, train=True, download=True, transform=self.train_transform)
            plain = CIFAR10(self.data_dir, train=True, download=True, transform=self.test_transform)
            train_split, val_split = random_split(augmented, [45000, 5000])
            self.cifar_train = train_split
            # Validation reuses the held-out indices but reads them through the
            # un-augmented transform, so metrics are not perturbed by random
            # crops/flips. This keeps a second copy of the training set in memory.
            self.cifar_val = torch.utils.data.Subset(plain, val_split.indices)

        if stage in ("test", None):
            self.cifar_test = CIFAR10(self.data_dir, train=False, download=True, transform=self.test_transform)

    def _make_loader(self, dataset, shuffle=False):
        """Build a DataLoader with the module's batch size and one worker per CPU."""
        # os.cpu_count() may return None; DataLoader needs a non-negative int.
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle,
                          num_workers=os.cpu_count() or 0)

    def train_dataloader(self):
        """Return the training loader (reshuffled every epoch)."""
        return self._make_loader(self.cifar_train, shuffle=True)

    def val_dataloader(self):
        """Return the validation loader."""
        return self._make_loader(self.cifar_val)

    def test_dataloader(self):
        """Return the test loader."""
        return self._make_loader(self.cifar_test)

    def collect_misclassified_images(self, num_images):
        """Scan the test set until at least ``num_images`` errors are found.

        Inference runs in eval mode without gradient tracking on the model's
        current device; the previous train/eval mode is restored afterwards.
        The test dataset is created on demand if ``setup("test")`` has not run.

        Args:
            num_images: Maximum number of misclassified samples to return.

        Returns:
            Tuple ``(images, true_labels, predicted_labels, num_found)``. The
            three lists hold CPU tensors and have at most ``num_images``
            entries (fewer if the test set contains fewer errors);
            ``num_found`` is the number of errors seen before scanning stopped.
        """
        if not hasattr(self, "cifar_test"):
            self.setup("test")

        images, true_labels, predicted_labels = [], [], []
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for x, y in self.test_dataloader():
                    x = x.to(self.device)
                    y = y.to(self.device)
                    pred = self(x).argmax(dim=1, keepdim=True)
                    # squeeze(1) keeps the batch dimension even for a batch of one.
                    wrong = pred.squeeze(1).ne(y)
                    images.extend(x[wrong].cpu())
                    true_labels.extend(y[wrong].cpu())
                    predicted_labels.extend(pred[wrong].cpu())
                    if len(images) >= num_images:
                        break
        finally:
            self.train(was_training)

        return images[:num_images], true_labels[:num_images], predicted_labels[:num_images], len(images)

    def normalize_image(self, img_tensor):
        """Min-max scale ``img_tensor`` to the range [0, 1].

        A constant input (max == min) yields all zeros instead of NaN.
        """
        min_val = img_tensor.min()
        span = img_tensor.max() - min_val
        if span == 0:
            return img_tensor - min_val
        return (img_tensor - min_val) / span

    def _gradcam_overlays(self, images, true_labels, target_layer, transparency):
        """Render one GradCAM overlay per image, targeting each image's true class.

        Side effect: the model is moved to the CPU (and stays there). The
        train/eval mode in effect before the call is restored.
        """
        # -2 selects the second stage; any other value selects the third stage.
        layer = self.convblock2_l1 if target_layer == -2 else self.convblock3_l1
        mean = np.array(self.CIFAR_MEAN)
        std = np.array(self.CIFAR_STD)

        overlays = []
        was_training = self.training
        self.cpu()
        self.eval()
        try:
            # target_layers must be a list of modules; passing the Sequential
            # itself would be iterated and hook each of its sub-layers. The
            # context manager removes the hooks again on exit.
            with GradCAM(model=self, target_layers=[layer]) as grad_cam:
                for image, label in zip(images, true_labels):
                    image = image.cpu()
                    # Undo Normalize and convert CHW -> HWC for plotting.
                    rgb = np.clip(std * image.numpy().transpose(1, 2, 0) + mean, 0, 1)
                    targets = [ClassifierOutputTarget(int(label))]
                    grayscale_cam = grad_cam(input_tensor=image.unsqueeze(0), targets=targets)[0, :]
                    overlays.append(
                        show_cam_on_image(rgb, grayscale_cam, use_rgb=True, image_weight=transparency)
                    )
        finally:
            self.train(was_training)
        return overlays

    def get_gradcam_images(self, target_layer=-1, transparency=0.5, num_images=10):
        """Return GradCAM overlays for up to ``num_images`` misclassified test images.

        Args:
            target_layer: ``-2`` targets ``convblock2_l1``; any other value
                targets ``convblock3_l1``.
            transparency: Weight of the original image in the blend
                (passed as ``image_weight`` to ``show_cam_on_image``).
            num_images: Maximum number of overlays to produce.

        Returns:
            List of overlay images as returned by ``show_cam_on_image``.
        """
        images, true_labels, _, _ = self.collect_misclassified_images(num_images)
        return self._gradcam_overlays(images, true_labels, target_layer, transparency)

    def create_layout(self, num_images, use_gradcam):
        """Create a figure and GridSpec with one row per image.

        Columns are label | original image, plus a GradCAM column when
        ``use_gradcam`` is true.

        Returns:
            Tuple ``(fig, gs)``.
        """
        if use_gradcam:
            num_cols, width_ratios = 3, [0.3, 1, 1]
        else:
            num_cols, width_ratios = 2, [0.5, 1]
        fig = plt.figure(figsize=(12, 5 * num_images))
        gs = gridspec.GridSpec(num_images, num_cols, figure=fig, width_ratios=width_ratios)
        return fig, gs

    def show_images_with_labels(self, fig, gs, i, img, label_text, use_gradcam=False, gradcam_img=None):
        """Draw row ``i`` of the grid: label text, original image and optional GradCAM image."""
        ax_img = fig.add_subplot(gs[i, 1])
        ax_img.imshow(img)
        ax_img.set_title("Original Image")
        ax_img.axis("off")

        if use_gradcam:
            ax_gradcam = fig.add_subplot(gs[i, 2])
            ax_gradcam.imshow(gradcam_img)
            ax_gradcam.set_title("GradCAM Image")
            ax_gradcam.axis("off")

        ax_label = fig.add_subplot(gs[i, 0])
        ax_label.text(0, 0.5, label_text, fontsize=10, verticalalignment='center')
        ax_label.axis("off")

    def show_misclassified_images(self, num_images=10, use_gradcam=False, gradcam_layer=-1, transparency=0.5):
        """Plot misclassified test images with their true and predicted labels.

        Args:
            num_images: Maximum number of rows to draw.
            use_gradcam: Also draw a GradCAM overlay next to each image.
            gradcam_layer: Layer selector forwarded to the GradCAM helper
                (``-2`` for ``convblock2_l1``, otherwise ``convblock3_l1``).
            transparency: Weight of the original image in the GradCAM blend.

        Returns:
            The matplotlib figure. It has no axes if nothing was misclassified.
        """
        images, true_labels, predicted_labels, _ = self.collect_misclassified_images(num_images)
        shown = len(images)

        # GridSpec rejects zero rows, so keep at least one (empty) row.
        fig, gs = self.create_layout(max(shown, 1), use_gradcam)

        if use_gradcam:
            # Reuse the images collected above instead of scanning the test set again.
            overlays = self._gradcam_overlays(images, true_labels, gradcam_layer, transparency)
        else:
            overlays = [None] * shown

        rows = zip(images, true_labels, predicted_labels, overlays)
        for row, (image, truth, guess, overlay) in enumerate(rows):
            img = self.normalize_image(image.numpy().transpose((1, 2, 0)))
            label_text = f"True Label: {self.classes[int(truth)]}\nPredicted Label: {self.classes[int(guess)]}"
            self.show_images_with_labels(fig, gs, row, img, label_text, use_gradcam, overlay)

        plt.tight_layout()
        return fig