#myaddition
"""LightningModule wrapping ``CustomResNet`` for CIFAR-10 training.

The module bundles the model, the loss/metric logging, the optimizer with a
step-wise ``LambdaLR`` schedule, and the CIFAR-10 data hooks.
"""

import os

import torch
import torch.optim as optim
import torchvision
from pytorch_lightning import LightningModule
from torch import nn
from torch.nn import functional as F
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader, random_split
from torchmetrics.functional import accuracy
from torchvision.datasets import CIFAR10

from Custom_Resnet_v1 import CustomResNet
from cyclic_lr_util import custom_one_cycle_lr
from data_transform_cifar10_custom_resnet import get_train_transform, get_test_transform

# Dataset root; overridable through the PATH_DATASETS environment variable.
PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
# 1 when at least one CUDA device is visible, otherwise 0.
AVAIL_GPUS = min(1, torch.cuda.device_count())
# Larger batches when a GPU is available.
BATCH_SIZE = 256 if AVAIL_GPUS else 64

# Callable handed to LambdaLR in configure_optimizers.
# NOTE(review): no_of_images=50176 and batch_size=2 match neither BATCH_SIZE
# nor the 46000-image training split used below - confirm these are intended.
one_cyle_lr = custom_one_cycle_lr(
    no_of_images=50176,
    batch_size=2,
    base_lr=0.04,
    max_lr=0.4,
    final_lr=0.004,
    epoch_stage1=5,
    epoch_stage2=18,
    total_epochs=24,
)


class Assignment12Resnet(LightningModule):
    """CIFAR-10 classifier built around ``CustomResNet``.

    Args:
        lr: Stored as ``self.learning_rate`` and recorded as a hyperparameter.
            NOTE(review): the optimizer below uses a fixed 0.04 and does not
            read this value - confirm whether that is intentional.
        data_dir: Directory the CIFAR-10 files are downloaded to / read from.
    """

    def __init__(self, lr=0.05, data_dir=PATH_DATASETS):
        super().__init__()

        # Set our init args as class attributes
        self.data_dir = data_dir
        self.learning_rate = lr

        # Hardcode some dataset specific attributes
        self.num_classes = 10
        self.train_transform = get_train_transform()
        self.test_transform = get_test_transform()

        # Datasets are created in setup(); declared here so the attributes
        # always exist.
        self.cifar10_trainset = None
        self.cifar10_testset = None
        self.cifar_train = None
        self.cifar_val = None

        self.save_hyperparameters()
        self.model = CustomResNet()

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute, log (``train_loss``) and return the loss for one batch."""
        x, y = batch
        logits = self(x)
        # nll_loss expects log-probabilities; presumably CustomResNet ends in
        # log_softmax - TODO confirm.
        loss = F.nll_loss(logits, y)
        self.log("train_loss", loss)
        return loss

    def evaluate(self, batch, stage=None):
        """Compute loss and accuracy for ``batch``; log them when ``stage`` is set.

        Args:
            batch: ``(inputs, targets)`` pair.
            stage: Prefix for the logged metric names (``"val"`` / ``"test"``).
                Nothing is logged when it is falsy.
        """
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y, task="multiclass", num_classes=self.num_classes)

        if stage:
            self.log(f"{stage}_loss", loss, prog_bar=True)
            self.log(f"{stage}_acc", acc, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        """Log ``val_loss`` / ``val_acc`` for one validation batch."""
        self.evaluate(batch, "val")

    def test_step(self, batch, batch_idx):
        """Log ``test_loss`` / ``test_acc`` for one test batch."""
        self.evaluate(batch, "test")

    def configure_optimizers(self):
        """Return SGD with momentum plus a per-step ``LambdaLR`` scheduler."""
        optimizer = optim.SGD(self.model.parameters(), lr=0.04, momentum=0.9)
        scheduler_dict = {
            # One lambda for the optimizer's single parameter group.
            "scheduler": LambdaLR(optimizer, lr_lambda=[one_cyle_lr]),
            # Step the scheduler after every optimizer step, not every epoch.
            "interval": "step",
        }
        return {"optimizer": optimizer, "lr_scheduler": scheduler_dict}

    ####################
    # DATA RELATED HOOKS
    ####################

    def prepare_data(self):
        """Download the CIFAR-10 train and test archives into ``data_dir``."""
        # download
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Create the datasets needed for ``stage`` (``None`` means all)."""
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            self.cifar10_trainset = CIFAR10(
                root=self.data_dir,
                train=True,
                download=True,
                transform=self.train_transform,
            )
            self.cifar_train, self.cifar_val = random_split(
                self.cifar10_trainset, [46000, 4000]
            )

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.cifar10_testset = CIFAR10(
                root=self.data_dir,
                train=False,
                download=True,
                transform=self.test_transform,
            )

    def _make_dataloader(self, dataset, shuffle):
        """Build a ``DataLoader`` over ``dataset`` with the shared settings."""
        # os.cpu_count() may return None; DataLoader needs an int, so fall
        # back to loading in the main process.
        workers = os.cpu_count() or 0
        return DataLoader(
            dataset,
            batch_size=BATCH_SIZE,
            shuffle=shuffle,
            num_workers=workers,
        )

    def train_dataloader(self):
        """Return the shuffled loader over the training split."""
        return self._make_dataloader(self.cifar_train, shuffle=True)

    def val_dataloader(self):
        """Return the unshuffled loader over the validation split."""
        return self._make_dataloader(self.cifar_val, shuffle=False)

    def test_dataloader(self):
        """Return the unshuffled loader over the test set."""
        return self._make_dataloader(self.cifar10_testset, shuffle=False)