import os

os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"

from os import path, makedirs, listdir
import sys
import numpy as np
np.random.seed(1)
import random
random.seed(1)

import torch
from torch import nn
from torch.backends import cudnn
from torch.utils.data import Dataset, DataLoader
import torch.optim.lr_scheduler as lr_scheduler

from apex import amp

from adamw import AdamW
from losses import dice_round, ComboLoss

import pandas as pd
from tqdm import tqdm
import timeit
import cv2

from zoo.models import SeNet154_Unet_Loc

from imgaug import augmenters as iaa

from utils import *

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

import gc

cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)

train_dirs = ['train', 'tier3']

models_folder = 'weights'

input_shape = (480, 480)

# Collect all pre-disaster images; each has a matching mask under /masks/.
all_files = []
for d in train_dirs:
    for f in sorted(listdir(path.join(d, 'images'))):
        if '_pre_disaster.png' in f:
            all_files.append(path.join(d, 'images', f))


class TrainData(Dataset):
    def __init__(self, train_idxs):
        super().__init__()
        self.train_idxs = train_idxs
        self.elastic = iaa.ElasticTransformation(alpha=(0.25, 1.2), sigma=0.2)

    def __len__(self):
        return len(self.train_idxs)

    def __getitem__(self, idx):
        _idx = self.train_idxs[idx]
        fn = all_files[_idx]

        img = cv2.imread(fn, cv2.IMREAD_COLOR)
        # Occasionally substitute the post-disaster image; the building
        # footprint mask is the same for both.
        if random.random() > 0.96:
            img = cv2.imread(fn.replace('_pre_disaster', '_post_disaster'), cv2.IMREAD_COLOR)

        msk0 = cv2.imread(fn.replace('/images/', '/masks/'), cv2.IMREAD_UNCHANGED)

        # Geometric augmentations: vertical flip, 90-degree rotations,
        # random shift, small rotation with scaling.
        if random.random() > 0.6:
            img = img[::-1, ...]
            msk0 = msk0[::-1, ...]

        if random.random() > 0.1:
            rot = random.randrange(4)
            if rot > 0:
                img = np.rot90(img, k=rot)
                msk0 = np.rot90(msk0, k=rot)

        if random.random() > 0.7:
            shift_pnt = (random.randint(-320, 320), random.randint(-320, 320))
            img = shift_image(img, shift_pnt)
            msk0 = shift_image(msk0, shift_pnt)

        if random.random() > 0.4:
            rot_pnt = (img.shape[0] // 2 + random.randint(-320, 320),
                       img.shape[1] // 2 + random.randint(-320, 320))
            scale = 0.9 + random.random() * 0.2
            angle = random.randint(0, 20) - 10
            if (angle != 0) or (scale != 1):
                img = rotate_image(img, angle, scale, rot_pnt)
                msk0 = rotate_image(msk0, angle, scale, rot_pnt)

        # Random crop size, jittered around the target input size.
        crop_size = input_shape[0]
        if random.random() > 0.2:
            crop_size = random.randint(int(input_shape[0] / 1.1), int(input_shape[0] / 0.9))

        # Sample a few candidate crops and keep the one covering the most
        # building pixels, biasing batches toward non-empty masks.
        bst_x0 = random.randint(0, img.shape[1] - crop_size)
        bst_y0 = random.randint(0, img.shape[0] - crop_size)
        bst_sc = -1
        try_cnt = random.randint(1, 5)
        for i in range(try_cnt):
            x0 = random.randint(0, img.shape[1] - crop_size)
            y0 = random.randint(0, img.shape[0] - crop_size)
            _sc = msk0[y0:y0 + crop_size, x0:x0 + crop_size].sum()
            if _sc > bst_sc:
                bst_sc = _sc
                bst_x0 = x0
                bst_y0 = y0
        x0 = bst_x0
        y0 = bst_y0

        img = img[y0:y0 + crop_size, x0:x0 + crop_size, :]
        msk0 = msk0[y0:y0 + crop_size, x0:x0 + crop_size]

        if crop_size != input_shape[0]:
            img = cv2.resize(img, input_shape, interpolation=cv2.INTER_LINEAR)
            msk0 = cv2.resize(msk0, input_shape, interpolation=cv2.INTER_LINEAR)

        # Photometric augmentations.
        if random.random() > 0.95:
            img = shift_channels(img, random.randint(-5, 5), random.randint(-5, 5), random.randint(-5, 5))

        if random.random() > 0.9597:
            img = change_hsv(img, random.randint(-5, 5), random.randint(-5, 5), random.randint(-5, 5))
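        # Note on the convention used throughout __getitem__: a guard of the
        # form `random.random() > p` fires with probability 1 - p, so the
        # `> 0.92` branches below each apply their transform to roughly 8%
        # of samples, and the nested if/elif picks at most one per branch.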
        if random.random() > 0.92:
            if random.random() > 0.92:
                img = clahe(img)
            elif random.random() > 0.92:
                img = gauss_noise(img)
            elif random.random() > 0.92:
                img = cv2.blur(img, (3, 3))
        elif random.random() > 0.92:
            if random.random() > 0.92:
                img = saturation(img, 0.9 + random.random() * 0.2)
            elif random.random() > 0.92:
                img = brightness(img, 0.9 + random.random() * 0.2)
            elif random.random() > 0.92:
                img = contrast(img, 0.9 + random.random() * 0.2)

        if random.random() > 0.95:
            el_det = self.elastic.to_deterministic()
            img = el_det.augment_image(img)

        # Binarize the mask and convert image and mask to CHW tensors.
        msk = msk0[..., np.newaxis]
        msk = (msk > 127) * 1

        img = preprocess_inputs(img)

        img = torch.from_numpy(img.transpose((2, 0, 1))).float()
        msk = torch.from_numpy(msk.transpose((2, 0, 1))).long()

        sample = {'img': img, 'msk': msk, 'fn': fn}
        return sample


class ValData(Dataset):
    def __init__(self, image_idxs):
        super().__init__()
        self.image_idxs = image_idxs

    def __len__(self):
        return len(self.image_idxs)

    def __getitem__(self, idx):
        _idx = self.image_idxs[idx]
        fn = all_files[_idx]

        img = cv2.imread(fn, cv2.IMREAD_COLOR)
        msk0 = cv2.imread(fn.replace('/images/', '/masks/'), cv2.IMREAD_UNCHANGED)

        msk = msk0[..., np.newaxis]
        msk = (msk > 127) * 1

        img = preprocess_inputs(img)

        img = torch.from_numpy(img.transpose((2, 0, 1))).float()
        msk = torch.from_numpy(msk.transpose((2, 0, 1))).long()

        sample = {'img': img, 'msk': msk, 'fn': fn}
        return sample


def validate(net, data_loader):
    dices0 = []

    _thr = 0.5

    with torch.no_grad():
        for i, sample in enumerate(tqdm(data_loader)):
            msks = sample["msk"].numpy()
            imgs = sample["img"].cuda(non_blocking=True)

            out = net(imgs)

            msk_pred = torch.sigmoid(out[:, 0, ...]).cpu().numpy()

            for j in range(msks.shape[0]):
                dices0.append(dice(msks[j, 0], msk_pred[j] > _thr))

    d0 = np.mean(dices0)

    print("Val Dice: {}".format(d0))
    return d0


def evaluate_val(data_val, best_score, model, snapshot_name, current_epoch):
    model = model.eval()
    d = validate(model, data_loader=data_val)

    # Save a snapshot whenever the validation Dice improves.
    if d > best_score:
        torch.save({
            'epoch': current_epoch + 1,
            'state_dict': model.state_dict(),
            'best_score': d,
        }, path.join(models_folder, snapshot_name + '_best'))
        best_score = d

    print("score: {}\tscore_best: {}".format(d, best_score))
    return best_score


def train_epoch(current_epoch, seg_loss, model, optimizer, scheduler, train_data_loader):
    losses = AverageMeter()
    dices = AverageMeter()

    iterator = tqdm(train_data_loader)
    model.train()
    for i, sample in enumerate(iterator):
        imgs = sample["img"].cuda(non_blocking=True)
        msks = sample["msk"].cuda(non_blocking=True)

        out = model(imgs)

        loss = seg_loss(out, msks)

        with torch.no_grad():
            _probs = torch.sigmoid(out[:, 0, ...])
            dice_sc = 1 - dice_round(_probs, msks[:, 0, ...])

        losses.update(loss.item(), imgs.size(0))
        dices.update(dice_sc, imgs.size(0))

        iterator.set_description(
            "epoch: {}; lr {:.7f}; Loss {loss.val:.4f} ({loss.avg:.4f}); Dice {dice.val:.4f} ({dice.avg:.4f})".format(
                current_epoch, scheduler.get_lr()[-1], loss=losses, dice=dices))

        optimizer.zero_grad()
        # apex AMP: scale the loss so fp16 gradients do not underflow.
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 0.999)
        optimizer.step()

    scheduler.step(current_epoch)

    print("epoch: {}; lr {:.7f}; Loss {loss.avg:.4f}; Dice {dice.avg:.4f}".format(
        current_epoch, scheduler.get_lr()[-1], loss=losses, dice=dices))
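# `dice` is imported from utils via the star import above. For readers
# without that module, it is assumed to compute the Sørensen-Dice
# coefficient of two binary masks; below is a minimal reference sketch
# under that assumption (the name `_dice_reference` is hypothetical and
# this function is not used by the training loop):
def _dice_reference(gt, pred, eps=1e-6):
    inter = np.logical_and(gt > 0, pred > 0).sum()
    return (2.0 * inter + eps) / ((gt > 0).sum() + (pred > 0).sum() + eps)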
if __name__ == '__main__':
    t0 = timeit.default_timer()

    makedirs(models_folder, exist_ok=True)

    seed = int(sys.argv[1])
    # vis_dev = sys.argv[2]
    # os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    # os.environ["CUDA_VISIBLE_DEVICES"] = vis_dev

    cudnn.benchmark = True

    batch_size = 14
    val_batch_size = 4

    snapshot_name = 'se154_loc_{}_1'.format(seed)

    # Hold out 10% of the images for validation; the split depends on the
    # seed, so runs with different seeds train and validate on different data.
    train_idxs, val_idxs = train_test_split(np.arange(len(all_files)), test_size=0.1, random_state=seed)

    np.random.seed(seed + 321)
    random.seed(seed + 321)

    steps_per_epoch = len(train_idxs) // batch_size
    validation_steps = len(val_idxs) // val_batch_size

    print('steps_per_epoch', steps_per_epoch, 'validation_steps', validation_steps)

    data_train = TrainData(train_idxs)
    val_train = ValData(val_idxs)

    train_data_loader = DataLoader(data_train, batch_size=batch_size, num_workers=6,
                                   shuffle=True, pin_memory=False, drop_last=True)
    val_data_loader = DataLoader(val_train, batch_size=val_batch_size, num_workers=6,
                                 shuffle=False, pin_memory=False)

    model = SeNet154_Unet_Loc().cuda()

    params = model.parameters()

    optimizer = AdamW(params, lr=0.00015, weight_decay=1e-6)

    # Mixed-precision setup must wrap the model before DataParallel.
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    scheduler = lr_scheduler.MultiStepLR(optimizer,
                                         milestones=[3, 7, 11, 15, 19, 23, 27, 33, 41, 50, 60, 70,
                                                     90, 110, 130, 150, 170, 180, 190],
                                         gamma=0.5)

    model = nn.DataParallel(model).cuda()

    seg_loss = ComboLoss({'dice': 1.0, 'focal': 14.0}, per_image=False).cuda()  # True

    best_score = 0
    _cnt = -1
    torch.cuda.empty_cache()
    for epoch in range(30):
        train_epoch(epoch, seg_loss, model, optimizer, scheduler, train_data_loader)
        torch.cuda.empty_cache()
        best_score = evaluate_val(val_data_loader, best_score, model, snapshot_name, epoch)

    elapsed = timeit.default_timer() - t0
    print('Time: {:.3f} min'.format(elapsed / 60))
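# Expected invocation (the script file name here is illustrative; the single
# required argument is the seed controlling the train/val split and the
# snapshot name):
#
#   python train_se154_loc.py 0
#
# To restore the best snapshot later, the keys match the torch.save call in
# evaluate_val above:
#
#   checkpoint = torch.load(path.join(models_folder, snapshot_name + '_best'))
#   model.load_state_dict(checkpoint['state_dict'])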