import os

os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"

from os import path, makedirs, listdir
from pathlib import Path
import sys

import numpy as np
np.random.seed(1)
import random
random.seed(1)

import torch
from torch import nn
from torch.backends import cudnn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.optim.lr_scheduler as lr_scheduler
from warmup_scheduler import GradualWarmupScheduler
from torch import amp
from torch.nn.parallel import DistributedDataParallel

from adamw import AdamW
from losses import dice_round, ComboLoss

import pandas as pd
from tqdm import tqdm
import timeit
import cv2

from zoo.models import SeResNext50_Unet_Loc

from imgaug import augmenters as iaa

from utils import *
from utils import all_gather

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

import gc
import wandb

cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)

train_dirs = ['train', 'tier3']
ROOT = './xview2'
models_folder = 'weights'

input_shape = (512, 512)

# Collect all pre-disaster images from the training tiers.
all_files = []
for d in train_dirs:
    for f in sorted(listdir(path.join(ROOT, d, 'images'))):
        if '_pre_disaster.png' in f:
            all_files.append(path.join(ROOT, d, 'images', f))


class TrainData(Dataset):
    def __init__(self, train_idxs):
        super().__init__()
        self.train_idxs = train_idxs
        self.elastic = iaa.ElasticTransformation(alpha=(0.25, 1.2), sigma=0.2)

    def __len__(self):
        return len(self.train_idxs)

    def __getitem__(self, idx):
        _idx = self.train_idxs[idx]

        fn = all_files[_idx]

        img = cv2.imread(fn, cv2.IMREAD_COLOR)
        # Occasionally substitute the post-disaster image of the same tile.
        if random.random() > 0.985:
            img = cv2.imread(fn.replace('_pre_disaster', '_post_disaster'), cv2.IMREAD_COLOR)

        msk0 = cv2.imread(fn.replace('/images/', '/masks/'), cv2.IMREAD_UNCHANGED)

        # Random vertical flip.
        if random.random() > 0.5:
            img = img[::-1, ...]
            msk0 = msk0[::-1, ...]
        # Random rotation by a multiple of 90 degrees.
        if random.random() > 0.05:
            rot = random.randrange(4)
            if rot > 0:
                img = np.rot90(img, k=rot)
                msk0 = np.rot90(msk0, k=rot)

        # Random translation.
        if random.random() > 0.9:
            shift_pnt = (random.randint(-320, 320), random.randint(-320, 320))
            img = shift_image(img, shift_pnt)
            msk0 = shift_image(msk0, shift_pnt)

        # Random rotation/scaling around a jittered center.
        if random.random() > 0.9:
            rot_pnt = (img.shape[0] // 2 + random.randint(-320, 320),
                       img.shape[1] // 2 + random.randint(-320, 320))
            scale = 0.9 + random.random() * 0.2
            angle = random.randint(0, 20) - 10
            if (angle != 0) or (scale != 1):
                img = rotate_image(img, angle, scale, rot_pnt)
                msk0 = rotate_image(msk0, angle, scale, rot_pnt)

        crop_size = input_shape[0]
        if random.random() > 0.3:
            crop_size = random.randint(int(input_shape[0] / 1.1), int(input_shape[0] / 0.9))

        # Sample a few candidate crops and keep the one with the largest mask coverage.
        bst_x0 = random.randint(0, img.shape[1] - crop_size)
        bst_y0 = random.randint(0, img.shape[0] - crop_size)
        bst_sc = -1
        try_cnt = random.randint(1, 5)
        for i in range(try_cnt):
            x0 = random.randint(0, img.shape[1] - crop_size)
            y0 = random.randint(0, img.shape[0] - crop_size)
            _sc = msk0[y0:y0 + crop_size, x0:x0 + crop_size].sum()
            if _sc > bst_sc:
                bst_sc = _sc
                bst_x0 = x0
                bst_y0 = y0
        x0 = bst_x0
        y0 = bst_y0
        img = img[y0:y0 + crop_size, x0:x0 + crop_size, :]
        msk0 = msk0[y0:y0 + crop_size, x0:x0 + crop_size]

        if crop_size != input_shape[0]:
            img = cv2.resize(img, input_shape, interpolation=cv2.INTER_LINEAR)
            msk0 = cv2.resize(msk0, input_shape, interpolation=cv2.INTER_LINEAR)

        # Rare color / noise / blur augmentations.
        if random.random() > 0.99:
            img = shift_channels(img, random.randint(-5, 5), random.randint(-5, 5), random.randint(-5, 5))
        if random.random() > 0.99:
            img = change_hsv(img, random.randint(-5, 5), random.randint(-5, 5), random.randint(-5, 5))

        if random.random() > 0.99:
            if random.random() > 0.99:
                img = clahe(img)
            elif random.random() > 0.99:
                img = gauss_noise(img)
            elif random.random() > 0.99:
                img = cv2.blur(img, (3, 3))
        elif random.random() > 0.99:
            if random.random() > 0.99:
                img = saturation(img, 0.9 + random.random() * 0.2)
            elif random.random() > 0.99:
                img = brightness(img, 0.9 + random.random() * 0.2)
            elif random.random() > 0.99:
                img = contrast(img, 0.9 + random.random() * 0.2)

        if random.random() > 0.999:
            el_det = self.elastic.to_deterministic()
            img = el_det.augment_image(img)

        msk = msk0[..., np.newaxis]
        msk = (msk > 127) * 1

        img = preprocess_inputs(img)

        img = torch.from_numpy(img.transpose((2, 0, 1))).float()
        msk = torch.from_numpy(msk.transpose((2, 0, 1))).long()

        sample = {'img': img, 'msk': msk, 'fn': fn}
        return sample


class ValData(Dataset):
    def __init__(self, image_idxs):
        super().__init__()
        self.image_idxs = image_idxs

    def __len__(self):
        return len(self.image_idxs)

    def __getitem__(self, idx):
        _idx = self.image_idxs[idx]

        fn = all_files[_idx]

        img = cv2.imread(fn, cv2.IMREAD_COLOR)
        msk0 = cv2.imread(fn.replace('/images/', '/masks/'), cv2.IMREAD_UNCHANGED)

        msk = msk0[..., np.newaxis]
        msk = (msk > 127) * 1

        img = preprocess_inputs(img)

        img = torch.from_numpy(img.transpose((2, 0, 1))).float()
        msk = torch.from_numpy(msk.transpose((2, 0, 1))).long()

        sample = {'img': img, 'msk': msk, 'fn': fn}
        return sample


def validate(model, data_loader):
    dices0 = []

    _thr = 0.5

    with torch.no_grad():
        for i, sample in enumerate(tqdm(data_loader)):
            msks = sample["msk"].numpy()
            imgs = sample["img"].cuda(non_blocking=True)

            out = model(imgs)

            msk_pred = torch.sigmoid(out[:, 0, ...]).cpu().numpy()

            for j in range(msks.shape[0]):
                dices0.append(dice(msks[j, 0], msk_pred[j] > _thr))

    # Gather per-image dice scores from all distributed ranks before averaging.
    dices0_gathered = all_gather(dices0)
    dices0 = np.concatenate(dices0_gathered)
    d0 = np.mean(dices0)

    print("Val Dice: {}".format(d0))
    if wandb.run is not None:
        wandb.log({"val_dice": d0})

    return d0


def evaluate_val(data_val, best_score, model, snapshot_name, current_epoch, save=False):
    model = model.eval()
    d = validate(model, data_loader=data_val)

    if save and d > best_score:
        torch.save({
            'epoch': current_epoch + 1,
            'state_dict': model.state_dict(),
            'best_score': d,
        }, path.join(models_folder, snapshot_name + '_best'))
        best_score = d

    print("score: {}\tscore_best: {}".format(d, best_score))
    return best_score


def train_epoch(current_epoch, seg_loss, model, optimizer, scheduler, train_data_loader, scaler):
    losses = AverageMeter()
    dices = AverageMeter()

    iterator = tqdm(train_data_loader)
    model.train()
    for i, sample in enumerate(iterator):
        imgs = sample["img"].cuda(non_blocking=True)
        msks = sample["msk"].cuda(non_blocking=True)

        with torch.amp.autocast(device_type='cuda', dtype=torch.float32):
            out = model(imgs)
            loss = seg_loss(out, msks)

        with torch.no_grad():
            _probs = torch.sigmoid(out[:, 0, ...])
            dice_sc = 1 - dice_round(_probs, msks[:, 0, ...])

        losses.update(loss.item(), imgs.size(0))
        dices.update(dice_sc.item(), imgs.size(0))

        iterator.set_description(
            "epoch: {}; lr {:.7f}; Loss {loss.val:.4f} ({loss.avg:.4f}); Dice {dice.val:.4f} ({dice.avg:.4f})".format(
                current_epoch, scheduler.get_lr()[-1], loss=losses, dice=dices))

        if wandb.run is not None and i % 20 == 0:
            wandb.log(
                dict(
                    epoch=current_epoch,
                    lr=float(scheduler.get_lr()[-1]),
                    loss_avg=losses.avg,
                    loss_val=losses.val,
                    dice=dices.avg,
                )
            )

        optimizer.zero_grad()
        # Scaled backward pass with gradient clipping (replaces the old apex amp path).
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.1)
        scaler.step(optimizer)
        scaler.update()

    scheduler.step(current_epoch)

    print("epoch: {}; lr {:.7f}; Loss {loss.avg:.4f}; Dice {dice.avg:.4f}".format(
        current_epoch, scheduler.get_lr()[-1], loss=losses, dice=dices))


from torch.utils.data import DistributedSampler
import torch.distributed as dist
from climax import ClimaXLegacy
from argparse import ArgumentParser

if __name__ == '__main__':
    t0 = timeit.default_timer()

    parser = ArgumentParser()
    parser.add_argument('seed', type=int)
    parser.add_argument('--val', action='store_true')
    parser.add_argument('--resume', action='store_true')
    parser.add_argument('--eval-output', action='store_true')
    parser.add_argument('--model', default='climax')
    args = parser.parse_args()

    makedirs(models_folder, exist_ok=True)

    seed = args.seed

    # LOCAL_RANK is set by the distributed launcher (e.g. torchrun).
    local_rank = int(os.environ['LOCAL_RANK'])

    cudnn.benchmark = True

    batch_size = 2
    val_batch_size = 4

    dist.init_process_group(backend='nccl')
    torch.cuda.set_device(local_rank)

    # Only rank 0 logs to wandb; wandb.run stays None on the other ranks.
    if local_rank == 0:
        wandb.init()

    snapshot_name = '{}_loc_{}_0'.format(args.model, seed)

    train_idxs, val_idxs = train_test_split(np.arange(len(all_files)), test_size=0.1, random_state=seed)

    np.random.seed(seed + 123)
    random.seed(seed + 123)

    # Persist the split so that later runs reuse the same folds.
    if not os.path.isfile('folds.pth'):
        torch.save(
            dict(
                train_idxs=train_idxs,
                val_idxs=val_idxs,
            ), 'folds.pth')
    else:
        folds = torch.load('folds.pth')
        train_idxs = folds['train_idxs']
        val_idxs = folds['val_idxs']

    steps_per_epoch = len(train_idxs) // batch_size
    validation_steps = len(val_idxs) // val_batch_size

    print('steps_per_epoch', steps_per_epoch, 'validation_steps', validation_steps)

    data_train = TrainData(train_idxs)
    val_train = ValData(val_idxs)
    sampler_train = DistributedSampler(data_train, shuffle=True)
    sampler_val = DistributedSampler(val_train, shuffle=False)

    train_data_loader = DataLoader(data_train, sampler=sampler_train, batch_size=batch_size, num_workers=6,
                                   pin_memory=False, drop_last=True)
    val_data_loader = DataLoader(val_train, sampler=sampler_val, batch_size=val_batch_size, num_workers=6,
                                 pin_memory=False)

    variables = ['R', 'G', 'B']
    model = ClimaXLegacy(img_size=input_shape,
                         patch_size=16,
                         default_vars=variables,
                         pretrained='5.625deg.ckpt',
                         upsampling_steps=[
                             {"step_scale_factor": 2, "new_channel_dim": 1024, "feature_dim": 2048},
                             {"step_scale_factor": 2, "new_channel_dim": 512, "feature_dim": 1024},
                             {"step_scale_factor": 2, "new_channel_dim": 256, "feature_dim": 512},
                             {"step_scale_factor": 2, "new_channel_dim": 128, "feature_dim": 256},
                         ],
                         feature_extractor_type="res-net",
                         out_dim=1).cuda()

    params = model.parameters()

    if args.resume:
        snap_to_load = '{}_loc_{}_0_best'.format(args.model, seed)
        print("=> loading checkpoint '{}'".format(snap_to_load))
        checkpoint = torch.load(path.join(models_folder, snap_to_load), map_location='cpu')
        loaded_dict = checkpoint['state_dict']
        if 'module.' in list(loaded_dict.keys())[0]:
            loaded_dict = {k[7:]: v for k, v in loaded_dict.items()}
        sd = model.state_dict()
        for k in sd:
            if k in loaded_dict and sd[k].size() == loaded_dict[k].size():
                sd[k] = loaded_dict[k]
            elif 'pos_embed' in k:
                # Resize positional embeddings to the current token grid via bicubic interpolation.
                new_k = int(np.sqrt(sd[k].shape[1]))
                old_k = int(np.sqrt(loaded_dict[k].shape[1]))
                emb_dim = loaded_dict[k].shape[2]
                sd[k] = nn.functional.interpolate(
                    loaded_dict[k].view(1, old_k, old_k, emb_dim).permute(0, 3, 1, 2),
                    (new_k, new_k), mode='bicubic').permute(0, 2, 3, 1).view(*sd[k].shape)
                print(sd[k].shape, loaded_dict[k].shape)
            else:
                print(k)
        model.load_state_dict(sd, strict=True)
        print("loaded checkpoint '{}' (epoch {}, best_score {})"
              .format(snap_to_load, checkpoint['epoch'], checkpoint['best_score']))
        del loaded_dict
        del checkpoint
        gc.collect()
        torch.cuda.empty_cache()

    # Convert BatchNorm to SyncBatchNorm before wrapping in DDP.
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = DistributedDataParallel(model)

    optimizer = AdamW(params, lr=0.00015 * 0.1, weight_decay=1e-6)
    scaler = torch.cuda.amp.GradScaler()

    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[140, 180], gamma=0.1)

    seg_loss = ComboLoss({'dice': 1.0, 'focal': 10.0}, per_image=False).cuda()

    best_score = 0
    _cnt = -1
    torch.cuda.empty_cache()

    if args.val:
        best_score = evaluate_val(val_data_loader, best_score, model, snapshot_name, -1)
    elif args.eval_output:
        pred_folder = f'pred_loc_val_{args.model}'
        Path(pred_folder).mkdir(exist_ok=True)
        with torch.no_grad():
            # Shard the validation indices across ranks for prediction.
            indices = [x for x in val_idxs if x % dist.get_world_size() == dist.get_rank()]
            for idx in tqdm(indices):
                fn = all_files[idx]
                f = fn.split('/')[-1]

                img = cv2.imread(fn, cv2.IMREAD_COLOR)
                img = preprocess_inputs(img)

                # Test-time augmentation: original image plus its 180-degree flip.
                inp = []
                inp.append(img)
                inp.append(img[::-1, ::-1, ...])
                inp = np.asarray(inp, dtype='float')
                inp = torch.from_numpy(inp.transpose((0, 3, 1, 2))).float()
                inp = inp.cuda()

                pred = []
                msk = model(inp)
                msk = torch.sigmoid(msk)
                msk = msk.cpu().numpy()

                pred.append(msk[0, ...])
                pred.append(msk[1, :, ::-1, ::-1])

                pred_full = np.asarray(pred).mean(axis=0)

                msk = pred_full * 255
                msk = msk.astype('uint8').transpose(1, 2, 0)
                cv2.imwrite(path.join(pred_folder, '{0}.png'.format(f.replace('.png', '_part1.png'))),
                            msk[..., 0], [cv2.IMWRITE_PNG_COMPRESSION, 9])
    else:
        # Train for 200 epochs, validating (and possibly saving the best model) every other epoch.
        for epoch in range(200):
            sampler_train.set_epoch(epoch)  # reshuffle the distributed shards each epoch
            train_epoch(epoch, seg_loss, model, optimizer, scheduler, train_data_loader, scaler)
            if epoch % 2 == 0:
                _cnt += 1
                torch.cuda.empty_cache()
                best_score = evaluate_val(val_data_loader, best_score, model, snapshot_name, epoch, True)

    elapsed = timeit.default_timer() - t0
    print('Time: {:.3f} min'.format(elapsed / 60))
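
# Usage sketch (assumptions: this file is launched with torchrun, which provides LOCAL_RANK
# and the rendezvous variables consumed by dist.init_process_group above; the script name
# "train_loc.py" and the GPU count are illustrative, not taken from this repo):
#
#   torchrun --nproc_per_node=2 train_loc.py 1               # train with seed 1
#   torchrun --nproc_per_node=2 train_loc.py 1 --val         # validation only
#   torchrun --nproc_per_node=2 train_loc.py 1 --eval-output # write TTA-averaged masks
#
# The positional argument is the random seed; --eval-output saves the averaged
# localization masks to pred_loc_val_<model>/.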