import types
import time
import random
import clip
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from argparse import ArgumentParser
import pytorch_lightning as pl

from data import get_dataset, get_available_datasets
from encoding.models import get_segmentation_model
from encoding.nn import SegmentationLosses
from encoding.utils import batch_pix_accuracy, batch_intersection_union

# add mixed precision
import torch.cuda.amp as amp
import numpy as np

from encoding.utils import SegmentationMetric


class LSegmentationModule(pl.LightningModule):
    """Lightning wrapper that wires a segmentation network to losses,
    metrics, optimizers and data loaders.

    NOTE(review): this base class never creates ``self.net``,
    ``self.criterion``, ``self.trainset``/``self.valset`` or the
    train/val transforms — a subclass (outside this file) is expected to
    set them before training starts.
    """

    def __init__(self, data_path, dataset, batch_size, base_lr, max_epochs, **kwargs):
        """Store training hyper-parameters.

        Args:
            data_path: root directory of the dataset on disk.
            dataset: dataset name (unused here; kept for subclass/CLI parity).
            batch_size: per-loader batch size.
            base_lr: learning rate tuned for batch size 16; scaled linearly
                with the actual batch size below.
            max_epochs: total epochs, used by the polynomial LR schedule.
            **kwargs: extra options read later via ``self.other_kwargs``
                (``ignore_index``, ``weight_decay``, ``midasproto``, ...).
        """
        super().__init__()

        self.data_path = data_path
        self.batch_size = batch_size
        # Linear LR scaling rule: base_lr was tuned for batch size 16.
        self.base_lr = base_lr / 16 * batch_size
        self.lr = self.base_lr
        self.epochs = max_epochs
        self.other_kwargs = kwargs
        # Mixed precision is deliberately OFF: enabling it made training
        # unstable and produced NaN losses.
        self.enabled = False
        self.scaler = amp.GradScaler(enabled=self.enabled)

    def forward(self, x):
        """Run the wrapped segmentation network on a batch of images."""
        return self.net(x)

    def evaluate(self, x, target=None):
        """Forward pass for evaluation.

        Returns the raw prediction when ``target`` is None; otherwise the
        ``(correct, labeled, inter, union)`` pixel-accuracy / IoU counts.
        """
        pred = self.net.forward(x)
        # Multi-output nets (e.g. with aux heads) return a tuple/list;
        # only the primary output is evaluated.
        if isinstance(pred, (tuple, list)):
            pred = pred[0]
        if target is None:
            return pred
        correct, labeled = batch_pix_accuracy(pred.data, target.data)
        inter, union = batch_intersection_union(pred.data, target.data, self.nclass)
        return correct, labeled, inter, union

    def evaluate_random(self, x, labelset, target=None):
        """Same as :meth:`evaluate` but with an explicit label set passed
        to the network (open-vocabulary inference)."""
        pred = self.net.forward(x, labelset)
        if isinstance(pred, (tuple, list)):
            pred = pred[0]
        if target is None:
            return pred
        correct, labeled = batch_pix_accuracy(pred.data, target.data)
        inter, union = batch_intersection_union(pred.data, target.data, self.nclass)
        return correct, labeled, inter, union

    def training_step(self, batch, batch_nb):
        """One optimization step: forward, loss, accuracy bookkeeping."""
        img, target = batch
        with amp.autocast(enabled=self.enabled):
            out = self(img)
            multi_loss = isinstance(out, tuple)
            if multi_loss:
                loss = self.criterion(*out, target)
            else:
                loss = self.criterion(out, target)
            # NOTE(review): with self.enabled == False this scale() is a
            # no-op. If AMP were re-enabled, the *scaled* loss would be
            # logged and returned to Lightning without a matching
            # unscale/step on the scaler — revisit before turning AMP on.
            loss = self.scaler.scale(loss)
        final_output = out[0] if multi_loss else out
        train_pred, train_gt = self._filter_invalid(final_output, target)
        # Accuracy is undefined on a batch where every pixel is ignored.
        if train_gt.nelement() != 0:
            self.train_accuracy(train_pred, train_gt)
        self.log("train_loss", loss)
        return loss

    def training_epoch_end(self, outs):
        """Log the epoch-level training accuracy."""
        self.log("train_acc_epoch", self.train_accuracy.compute())

    def validation_step(self, batch, batch_nb):
        """One validation step: loss plus running pixel-acc / IoU."""
        img, target = batch
        out = self(img)
        multi_loss = isinstance(out, tuple)
        if multi_loss:
            val_loss = self.criterion(*out, target)
        else:
            val_loss = self.criterion(out, target)
        final_output = out[0] if multi_loss else out
        valid_pred, valid_gt = self._filter_invalid(final_output, target)
        # val_iou accumulates across steps; get() returns the running value.
        self.val_iou.update(target, final_output)
        pixAcc, iou = self.val_iou.get()
        self.log("val_loss_step", val_loss)
        self.log("pix_acc_step", pixAcc)
        self.log(
            "val_acc_step",
            self.val_accuracy(valid_pred, valid_gt),
        )
        self.log("val_iou", iou)

    def validation_epoch_end(self, outs):
        """Log epoch-level validation metrics and reset the IoU tracker."""
        pixAcc, iou = self.val_iou.get()
        self.log("val_acc_epoch", self.val_accuracy.compute())
        self.log("val_iou_epoch", iou)
        self.log("pix_acc_epoch", pixAcc)
        self.val_iou.reset()

    def _filter_invalid(self, pred, target):
        """Drop pixels whose ground-truth label equals ``ignore_index``;
        returns (argmax predictions, targets) restricted to valid pixels."""
        valid = target != self.other_kwargs["ignore_index"]
        _, mx = torch.max(pred, dim=1)
        return mx[valid], target[valid]

    def configure_optimizers(self):
        """Build per-submodule parameter groups (10x LR on the heads) and a
        polynomial-decay LR schedule; Adam under the MiDaS protocol,
        SGD with momentum otherwise."""
        params_list = [
            {"params": self.net.pretrained.parameters(), "lr": self.base_lr},
        ]
        if hasattr(self.net, "scratch"):
            print("Found output scratch")
            params_list.append(
                {"params": self.net.scratch.parameters(), "lr": self.base_lr * 10}
            )
        if hasattr(self.net, "auxlayer"):
            print("Found auxlayer")
            params_list.append(
                {"params": self.net.auxlayer.parameters(), "lr": self.base_lr * 10}
            )
        if hasattr(self.net, "scale_inv_conv"):
            print(self.net.scale_inv_conv)
            print("Found scaleinv layers")
            # All four scale heads share the same boosted learning rate.
            for head in (
                self.net.scale_inv_conv,
                self.net.scale2_conv,
                self.net.scale3_conv,
                self.net.scale4_conv,
            ):
                params_list.append(
                    {"params": head.parameters(), "lr": self.base_lr * 10}
                )
        if self.other_kwargs["midasproto"]:
            print("Using midas optimization protocol")
            opt = torch.optim.Adam(
                params_list,
                lr=self.base_lr,
                betas=(0.9, 0.999),
                weight_decay=self.other_kwargs["weight_decay"],
            )
        else:
            opt = torch.optim.SGD(
                params_list,
                lr=self.base_lr,
                momentum=0.9,
                weight_decay=self.other_kwargs["weight_decay"],
            )
        # Polynomial ("poly") LR decay — identical for both optimizers,
        # so it is built once here instead of duplicated per branch.
        sch = torch.optim.lr_scheduler.LambdaLR(
            opt, lambda x: pow(1.0 - x / self.epochs, 0.9)
        )
        return [opt], [sch]

    def train_dataloader(self):
        """Shuffled training loader.

        NOTE(review): worker_init_fn reseeds only Python's ``random``
        module — numpy/torch RNGs in workers are untouched; confirm the
        augmentations only rely on ``random``.
        """
        return torch.utils.data.DataLoader(
            self.trainset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=16,
            worker_init_fn=lambda x: random.seed(time.time() + x),
        )

    def val_dataloader(self):
        """Deterministic (unshuffled) validation loader."""
        return torch.utils.data.DataLoader(
            self.valset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=16,
        )

    def get_trainset(self, dset, augment=False, **kwargs):
        """Create the training dataset; also initializes ``num_classes``
        and the train accuracy metric as side effects."""
        print(kwargs)
        # "train_x" selects the extended-augmentation pipeline.
        mode = "train_x" if augment else "train"
        print(mode)
        dset = get_dataset(
            dset,
            root=self.data_path,
            split="train",
            mode=mode,
            transform=self.train_transform,
            **kwargs
        )
        self.num_classes = dset.num_class
        self.train_accuracy = pl.metrics.Accuracy()
        return dset

    def get_valset(self, dset, augment=False, **kwargs):
        """Create the validation dataset; also initializes the validation
        accuracy and IoU metrics (requires get_trainset to have set
        ``num_classes`` first)."""
        self.val_accuracy = pl.metrics.Accuracy()
        self.val_iou = SegmentationMetric(self.num_classes)
        mode = "val_x" if augment else "val"
        print(mode)
        return get_dataset(
            dset,
            root=self.data_path,
            split="val",
            mode=mode,
            transform=self.val_transform,
            **kwargs
        )

    def get_criterion(self, **kwargs):
        """Build the segmentation loss (optionally with SE / auxiliary
        terms) for ``self.num_classes`` classes."""
        return SegmentationLosses(
            se_loss=kwargs["se_loss"],
            aux=kwargs["aux"],
            nclass=self.num_classes,
            se_weight=kwargs["se_weight"],
            aux_weight=kwargs["aux_weight"],
            ignore_index=kwargs["ignore_index"],
        )

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Attach this module's CLI options to ``parent_parser`` and
        return the extended parser."""
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument(
            "--data_path", type=str, help="path where dataset is stored"
        )
        parser.add_argument(
            "--dataset",
            choices=get_available_datasets(),
            default="ade20k",
            help="dataset to train on",
        )
        parser.add_argument(
            "--batch_size", type=int, default=16, help="size of the batches"
        )
        parser.add_argument(
            "--base_lr", type=float, default=0.004, help="learning rate"
        )
        parser.add_argument("--momentum", type=float, default=0.9, help="SGD momentum")
        parser.add_argument(
            "--weight_decay", type=float, default=1e-4, help="weight_decay"
        )
        parser.add_argument(
            "--aux", action="store_true", default=False, help="Auxilary Loss"
        )
        parser.add_argument(
            "--aux-weight",
            type=float,
            default=0.2,
            help="Auxilary loss weight (default: 0.2)",
        )
        parser.add_argument(
            "--se-loss",
            action="store_true",
            default=False,
            help="Semantic Encoding Loss SE-loss",
        )
        parser.add_argument(
            "--se-weight", type=float, default=0.2, help="SE-loss weight (default: 0.2)"
        )
        parser.add_argument(
            "--midasproto", action="store_true", default=False, help="midasprotocol"
        )
        parser.add_argument(
            "--ignore_index",
            type=int,
            default=-1,
            help="numeric value of ignore label in gt",
        )
        parser.add_argument(
            "--augment",
            action="store_true",
            default=False,
            help="Use extended augmentations",
        )
        return parser