padmanabhbosamia committed on
Commit
3d9390d
·
1 Parent(s): bc84d26

Create resnet.py

Browse files
Files changed (1) hide show
  1. resnet.py +309 -0
resnet.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torchsummary import summary
5
+ from io import BytesIO
6
+ import numpy as np
7
+ import os
8
+ from pytorch_lightning import LightningModule, Trainer
9
+ from torch import nn
10
+ from torch.nn import functional as F
11
+ from torch.utils.data import DataLoader, random_split
12
+ from torchmetrics import Accuracy
13
+ from torchvision import transforms
14
+ from torchvision.datasets import CIFAR10
15
+ from torch_lr_finder import LRFinder
16
+ import math
17
+ from pytorch_grad_cam import GradCAM
18
+ from pytorch_grad_cam.utils.image import show_cam_on_image
19
+ from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
20
+ from PIL import Image
21
+ import torch
22
+ from torch.utils.data import DataLoader, random_split
23
+ import torchvision.transforms as transforms
24
+ import torchvision.datasets as datasets
25
+ import pytorch_lightning as pl
26
+ import matplotlib.pyplot as plt
27
+ import matplotlib.gridspec as gridspec
28
+
29
+
30
# Root directory for the dataset files: read from the PATH_DATASETS
# environment variable, defaulting to the current working directory.
PATH_DATASETS = os.getenv("PATH_DATASETS", ".")

# Mini-batch size (samples per batch).
BATCH_SIZE = 256
32
+
33
+
34
+ import torch
35
+ import torch.nn as nn
36
+ import torch.nn.functional as F
37
+ from torchsummary import summary
38
+ from io import BytesIO
39
+ import numpy as np
40
+
41
+ # Model
42
class custom_ResNet(pl.LightningModule):
    """Small ResNet-style CIFAR-10 classifier packaged as a LightningModule.

    Besides the network and the train/val/test steps, the module owns its
    data pipeline (download, split, dataloaders) and a few inspection
    helpers that collect misclassified test images and render them,
    optionally with Grad-CAM overlays.

    Layer attribute names and the order of modules inside every
    ``nn.Sequential`` are part of the checkpoint format (state-dict keys)
    and must not be changed.
    """

    # Per-channel statistics fed to ``transforms.Normalize``; the same values
    # are reused to undo the normalisation before plotting.
    _CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
    _CIFAR_STD = (0.247, 0.243, 0.261)

    # Sizes of the train/validation partition of the 50,000-image training set.
    _TRAIN_SIZE = 45000
    _VAL_SIZE = 5000

    def __init__(self, data_dir=PATH_DATASETS):
        """Build the transforms and the network layers.

        Args:
            data_dir: Directory where CIFAR-10 is stored or downloaded to.
        """
        super().__init__()

        # Dataset-specific attributes.
        self.data_dir = data_dir
        self.classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
        self.num_classes = 10

        normalize = transforms.Normalize(self._CIFAR_MEAN, self._CIFAR_STD)
        # Training images are augmented (random crop + flip) before normalisation.
        self.train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),  # Convert PIL image to tensor
            normalize,
        ])
        # Evaluation images are only converted and normalised.
        self.test_transform = transforms.Compose([
            transforms.ToTensor(),  # Convert PIL image to tensor
            normalize,
        ])

        # PREPARATION BLOCK: 3 -> 64 channels, spatial size unchanged.
        self.prepblock = nn.Sequential(
            self._conv3x3(3, 64),
            nn.ReLU(),
            nn.BatchNorm2d(64),
        )

        # CONVOLUTION BLOCK 1: 64 -> 128 channels, 2x2 max-pool halves H and W.
        self.convblock1_l1 = nn.Sequential(
            self._conv3x3(64, 128),
            nn.MaxPool2d(2, 2),
            nn.ReLU(),
            nn.BatchNorm2d(128),
        )
        # Residual branch of block 1 (shape preserving, so it can be added back).
        self.convblock1_r1 = nn.Sequential(
            self._conv3x3(128, 128),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            self._conv3x3(128, 128),
            nn.ReLU(),
            nn.BatchNorm2d(128),
        )

        # CONVOLUTION BLOCK 2: 128 -> 256 channels, 2x2 max-pool.
        self.convblock2_l1 = nn.Sequential(
            self._conv3x3(128, 256),
            nn.MaxPool2d(2, 2),
            nn.ReLU(),
            nn.BatchNorm2d(256),
        )

        # CONVOLUTION BLOCK 3: 256 -> 512 channels, 2x2 max-pool.
        self.convblock3_l1 = nn.Sequential(
            self._conv3x3(256, 512),
            nn.MaxPool2d(2, 2),
            nn.ReLU(),
            nn.BatchNorm2d(512),
        )
        # Residual branch of block 3 (shape preserving).
        self.convblock3_r2 = nn.Sequential(
            self._conv3x3(512, 512),
            nn.ReLU(),
            nn.BatchNorm2d(512),
            self._conv3x3(512, 512),
            nn.ReLU(),
            nn.BatchNorm2d(512),
        )

        # CONVOLUTION BLOCK 4: 4x4 max-pool; for 32x32 inputs this leaves 1x1.
        self.convblock4_mp = nn.Sequential(nn.MaxPool2d(4))

        # OUTPUT BLOCK: fully connected layer producing one score per class.
        self.output_block = nn.Sequential(
            nn.Linear(in_features=512, out_features=self.num_classes, bias=False)
        )

    @staticmethod
    def _conv3x3(in_channels, out_channels):
        """Return the bias-free, size-preserving 3x3 convolution used throughout."""
        return nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(3, 3),
            padding=1,
            dilation=1,
            stride=1,
            bias=False,
        )

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images."""
        out = self.prepblock(x)

        # Block 1 with its skip connection.
        out = self.convblock1_l1(out)
        out = out + self.convblock1_r1(out)

        # Block 2 has no residual branch.
        out = self.convblock2_l1(out)

        # Block 3 with its skip connection.
        out = self.convblock3_l1(out)
        out = self.convblock3_r2(out) + out

        # Pool, flatten to (batch, features) and classify.
        out = self.convblock4_mp(out)
        out = out.view(out.size(0), -1)
        return F.log_softmax(self.output_block(out), dim=1)

    def _evaluate_batch(self, batch):
        """Run one batch and return ``(loss, predictions, accuracy)``."""
        x, y = batch
        log_probs = self(x)
        # forward() already returns log-probabilities, so NLL is the matching loss.
        loss = F.nll_loss(log_probs, y)
        pred = log_probs.argmax(dim=1, keepdim=True)
        acc = pred.eq(y.view_as(pred)).float().mean()
        return loss, pred, acc

    def training_step(self, batch, batch_idx):
        """Log training loss/accuracy (per step and per epoch) and return the loss."""
        loss, _, acc = self._evaluate_batch(batch)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_acc', acc, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss/accuracy and return the loss."""
        loss, _, acc = self._evaluate_batch(batch)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Log test loss/accuracy and return the predicted class indices."""
        loss, pred, acc = self._evaluate_batch(batch)
        self.log('test_loss', loss, prog_bar=True)
        self.log('test_acc', acc, prog_bar=True)
        return pred

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with lr=0.001."""
        return torch.optim.Adam(self.parameters(), lr=0.001)

    ####################
    # DATA RELATED HOOKS
    ####################

    def prepare_data(self):
        """Download both CIFAR-10 splits into ``self.data_dir``."""
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Create the datasets needed for ``stage`` ("fit", "test" or None for both)."""
        if stage == "fit" or stage is None:
            # Two views of the same training images: augmented for training,
            # plain for validation. Splitting a single augmented dataset would
            # apply the random crop/flip to validation images as well.
            augmented = CIFAR10(self.data_dir, train=True, download=True, transform=self.train_transform)
            plain = CIFAR10(self.data_dir, train=True, download=True, transform=self.test_transform)
            order = torch.randperm(len(augmented)).tolist()
            train_idx = order[:self._TRAIN_SIZE]
            val_idx = order[self._TRAIN_SIZE:self._TRAIN_SIZE + self._VAL_SIZE]
            self.cifar_train = torch.utils.data.Subset(augmented, train_idx)
            self.cifar_val = torch.utils.data.Subset(plain, val_idx)

        if stage == "test" or stage is None:
            self.cifar_test = CIFAR10(self.data_dir, train=False, download=True, transform=self.test_transform)

    def _make_loader(self, dataset, shuffle=False):
        """Build a DataLoader with the module-wide batch size and worker count."""
        # os.cpu_count() may return None; fall back to loading in the main process.
        workers = os.cpu_count() or 0
        return DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=shuffle, num_workers=workers)

    def train_dataloader(self):
        """Return the training loader (reshuffled every epoch)."""
        return self._make_loader(self.cifar_train, shuffle=True)

    def val_dataloader(self):
        """Return the validation loader."""
        return self._make_loader(self.cifar_val)

    def test_dataloader(self):
        """Return the test loader."""
        return self._make_loader(self.cifar_test)

    ####################
    # INSPECTION HELPERS
    ####################

    def collect_misclassified_images(self, num_images):
        """Scan the test set until at least ``num_images`` mistakes are found.

        The scan runs in eval mode without gradient tracking and restores the
        previous train/eval mode afterwards. If the test dataset has not been
        created yet, ``setup("test")`` is called first.

        Args:
            num_images: Maximum number of misclassified samples to return.

        Returns:
            A 4-tuple ``(images, true_labels, predicted_labels, num_collected)``:
            three lists of CPU tensors truncated to ``num_images`` entries, and
            the number of misclassified samples gathered before truncation.
        """
        if not hasattr(self, "cifar_test"):
            self.setup("test")

        images = []
        true_labels = []
        predicted_labels = []

        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for x, y in self.test_dataloader():
                    log_probs = self(x.to(self.device))
                    pred = log_probs.argmax(dim=1, keepdim=True).cpu()
                    # view(-1) keeps a 1-D mask even when the batch holds one sample.
                    wrong = pred.ne(y.view_as(pred)).view(-1)
                    images.extend(x[wrong])
                    true_labels.extend(y[wrong])
                    predicted_labels.extend(pred[wrong])

                    if len(images) >= num_images:
                        break
        finally:
            self.train(was_training)

        return images[:num_images], true_labels[:num_images], predicted_labels[:num_images], len(images)

    def normalize_image(self, img_tensor):
        """Min-max scale ``img_tensor`` to the range [0, 1].

        A constant image (max == min) is returned as all zeros rather than
        dividing by zero.
        """
        min_val = img_tensor.min()
        span = img_tensor.max() - min_val
        if span == 0:
            return img_tensor - min_val
        return (img_tensor - min_val) / span

    def _gradcam_overlays(self, images, true_labels, target_layer=-1, transparency=0.5):
        """Render a Grad-CAM overlay for each image, targeting its true class.

        Note: moves the whole model to the CPU and switches it to eval mode.
        """
        self.cpu()
        self.eval()
        layer = self.convblock2_l1 if target_layer == -2 else self.convblock3_l1
        # target_layers must be a list of modules; a bare nn.Sequential would be
        # iterated into its individual sub-layers.
        grad_cam = GradCAM(model=self, target_layers=[layer])

        mean = np.array(self._CIFAR_MEAN)
        std = np.array(self._CIFAR_STD)

        overlays = []
        for image, label in zip(images, true_labels):
            image = image.cpu()
            # CHW -> HWC, then undo the Normalize transform for display.
            rgb = image.numpy().transpose(1, 2, 0)
            rgb = np.clip(std * rgb + mean, 0, 1).astype(np.float32)
            targets = [ClassifierOutputTarget(int(label))]
            grayscale_cam = grad_cam(input_tensor=image.unsqueeze(0), targets=targets)[0, :]
            overlays.append(show_cam_on_image(rgb, grayscale_cam, use_rgb=True, image_weight=transparency))
        return overlays

    def get_gradcam_images(self, target_layer=-1, transparency=0.5, num_images=10):
        """Return Grad-CAM overlays for up to ``num_images`` misclassified test images.

        Args:
            target_layer: ``-2`` selects ``convblock2_l1``; any other value
                selects ``convblock3_l1``.
            transparency: Weight of the original image in the overlay.
            num_images: Maximum number of overlays to produce.

        Returns:
            A list of overlay images; shorter than ``num_images`` if fewer
            misclassified samples were found.
        """
        images, true_labels, _, _ = self.collect_misclassified_images(num_images)
        return self._gradcam_overlays(images, true_labels, target_layer=target_layer, transparency=transparency)

    def create_layout(self, num_images, use_gradcam):
        """Create a figure and a grid with one row per image.

        Columns are label text and original image, plus a Grad-CAM column
        when ``use_gradcam`` is true.
        """
        if use_gradcam:
            num_cols, ratios = 3, [0.3, 1, 1]
        else:
            num_cols, ratios = 2, [0.5, 1]
        fig = plt.figure(figsize=(12, 5 * num_images))
        gs = gridspec.GridSpec(num_images, num_cols, figure=fig, width_ratios=ratios)
        return fig, gs

    def show_images_with_labels(self, fig, gs, i, img, label_text, use_gradcam=False, gradcam_img=None):
        """Fill row ``i`` of the grid with the label text, image and optional overlay."""
        ax_label = fig.add_subplot(gs[i, 0])
        ax_label.text(0, 0.5, label_text, fontsize=10, verticalalignment='center')
        ax_label.axis("off")

        ax_img = fig.add_subplot(gs[i, 1])
        ax_img.imshow(img)
        ax_img.set_title("Original Image")
        ax_img.axis("off")

        if use_gradcam:
            ax_gradcam = fig.add_subplot(gs[i, 2])
            ax_gradcam.imshow(gradcam_img)
            ax_gradcam.set_title("GradCAM Image")
            ax_gradcam.axis("off")

    def show_misclassified_images(self, num_images=10, use_gradcam=False, gradcam_layer=-1, transparency=0.5):
        """Plot misclassified test images with their true and predicted labels.

        Args:
            num_images: Maximum number of rows to plot.
            use_gradcam: Whether to add a Grad-CAM overlay column.
            gradcam_layer: Layer selector forwarded to the Grad-CAM helper.
            transparency: Weight of the original image in the overlay.

        Returns:
            The matplotlib figure (empty if no misclassified image was found).
        """
        images, true_labels, predicted_labels, _ = self.collect_misclassified_images(num_images)
        shown = len(images)
        if shown == 0:
            return plt.figure()

        fig, gs = self.create_layout(shown, use_gradcam)

        # Overlays are computed from the images collected above so every row
        # shows the same sample in both image columns.
        overlays = None
        if use_gradcam:
            overlays = self._gradcam_overlays(images, true_labels, target_layer=gradcam_layer, transparency=transparency)

        for i in range(shown):
            # CHW tensor -> HWC array, rescaled to [0, 1] for imshow.
            img = self.normalize_image(images[i].cpu().numpy().transpose((1, 2, 0)))

            # Show true label and predicted label on the left, and images on the right
            label_text = f"True Label: {self.classes[int(true_labels[i])]}\nPredicted Label: {self.classes[int(predicted_labels[i])]}"
            self.show_images_with_labels(fig, gs, i, img, label_text, use_gradcam, overlays[i] if use_gradcam else None)

        plt.tight_layout()
        return fig