# trainclimax_loc copy.py
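# Distributed training script for the xView2 building-localization task.
# A ClimaX (ViT) backbone is fine-tuned on 512x512 crops of pre-disaster imagery
# to predict binary building-footprint masks, using DistributedDataParallel,
# mixed precision, and a Dice + Focal combo loss. The best checkpoint (by
# validation Dice) is written to the `weights` folder.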
import os
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"
from os import path, makedirs, listdir
from pathlib import Path
import sys
import numpy as np
np.random.seed(1)
import random
random.seed(1)
import torch
from torch import nn
from torch.backends import cudnn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.optim.lr_scheduler as lr_scheduler
from torch import amp
from adamw import AdamW
from losses import dice_round, ComboLoss
import pandas as pd
from tqdm import tqdm
import timeit
import cv2
from zoo.models import SeResNext50_Unet_Loc
from imgaug import augmenters as iaa
from utils import *
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from torch.nn.parallel import DistributedDataParallel
import gc
from utils import all_gather
import wandb
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)
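# Data configuration: pre-disaster images are collected from the xView2 `train`
# and `tier3` splits under ROOT; their binary masks are expected in a parallel
# `masks` directory. Training and evaluation use 512x512 crops.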
train_dirs = ['train', 'tier3']
ROOT = './xview2'
models_folder = 'weights'
input_shape = (512, 512)
all_files = []
for d in train_dirs:
for f in sorted(listdir(path.join(ROOT,d, 'images'))):
if '_pre_disaster.png' in f:
all_files.append(path.join(ROOT,d, 'images', f))
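# Training dataset: loads a pre-disaster image (occasionally substituting the
# post-disaster image as augmentation) together with its binary building mask,
# and applies randomized augmentations: flips, 90-degree rotations, shifts,
# rotation/scale, building-biased cropping, and rare color/blur/elastic jitter.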
class TrainData(Dataset):
def __init__(self, train_idxs):
super().__init__()
self.train_idxs = train_idxs
self.elastic = iaa.ElasticTransformation(alpha=(0.25, 1.2), sigma=0.2)
def __len__(self):
return len(self.train_idxs)
def __getitem__(self, idx):
_idx = self.train_idxs[idx]
fn = all_files[_idx]
img = cv2.imread(fn, cv2.IMREAD_COLOR)
if random.random() > 0.985:
img = cv2.imread(fn.replace('_pre_disaster', '_post_disaster'), cv2.IMREAD_COLOR)
msk0 = cv2.imread(fn.replace('/images/', '/masks/'), cv2.IMREAD_UNCHANGED)
if random.random() > 0.5:
img = img[::-1, ...]
msk0 = msk0[::-1, ...]
if random.random() > 0.05:
rot = random.randrange(4)
if rot > 0:
img = np.rot90(img, k=rot)
msk0 = np.rot90(msk0, k=rot)
if random.random() > 0.9:
shift_pnt = (random.randint(-320, 320), random.randint(-320, 320))
img = shift_image(img, shift_pnt)
msk0 = shift_image(msk0, shift_pnt)
if random.random() > 0.9:
rot_pnt = (img.shape[0] // 2 + random.randint(-320, 320), img.shape[1] // 2 + random.randint(-320, 320))
scale = 0.9 + random.random() * 0.2
angle = random.randint(0, 20) - 10
if (angle != 0) or (scale != 1):
img = rotate_image(img, angle, scale, rot_pnt)
msk0 = rotate_image(msk0, angle, scale, rot_pnt)
crop_size = input_shape[0]
if random.random() > 0.3:
crop_size = random.randint(int(input_shape[0] / 1.1), int(input_shape[0] / 0.9))
bst_x0 = random.randint(0, img.shape[1] - crop_size)
bst_y0 = random.randint(0, img.shape[0] - crop_size)
bst_sc = -1
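        # Sample up to `try_cnt` candidate crop positions and keep the one whose
        # mask sum is largest, biasing crops toward regions that contain buildings.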
try_cnt = random.randint(1, 5)
for i in range(try_cnt):
x0 = random.randint(0, img.shape[1] - crop_size)
y0 = random.randint(0, img.shape[0] - crop_size)
_sc = msk0[y0:y0+crop_size, x0:x0+crop_size].sum()
if _sc > bst_sc:
bst_sc = _sc
bst_x0 = x0
bst_y0 = y0
x0 = bst_x0
y0 = bst_y0
img = img[y0:y0+crop_size, x0:x0+crop_size, :]
msk0 = msk0[y0:y0+crop_size, x0:x0+crop_size]
if crop_size != input_shape[0]:
img = cv2.resize(img, input_shape, interpolation=cv2.INTER_LINEAR)
msk0 = cv2.resize(msk0, input_shape, interpolation=cv2.INTER_LINEAR)
if random.random() > 0.99:
img = shift_channels(img, random.randint(-5, 5), random.randint(-5, 5), random.randint(-5, 5))
if random.random() > 0.99:
img = change_hsv(img, random.randint(-5, 5), random.randint(-5, 5), random.randint(-5, 5))
if random.random() > 0.99:
if random.random() > 0.99:
img = clahe(img)
elif random.random() > 0.99:
img = gauss_noise(img)
elif random.random() > 0.99:
img = cv2.blur(img, (3, 3))
elif random.random() > 0.99:
if random.random() > 0.99:
img = saturation(img, 0.9 + random.random() * 0.2)
elif random.random() > 0.99:
img = brightness(img, 0.9 + random.random() * 0.2)
elif random.random() > 0.99:
img = contrast(img, 0.9 + random.random() * 0.2)
if random.random() > 0.999:
el_det = self.elastic.to_deterministic()
img = el_det.augment_image(img)
msk = msk0[..., np.newaxis]
msk = (msk > 127) * 1
img = preprocess_inputs(img)
img = torch.from_numpy(img.transpose((2, 0, 1))).float()
msk = torch.from_numpy(msk.transpose((2, 0, 1))).long()
sample = {'img': img, 'msk': msk, 'fn': fn}
return sample
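# Validation dataset: full-resolution images and masks with no augmentation,
# normalized the same way as the training data.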
class ValData(Dataset):
def __init__(self, image_idxs):
super().__init__()
self.image_idxs = image_idxs
def __len__(self):
return len(self.image_idxs)
def __getitem__(self, idx):
_idx = self.image_idxs[idx]
fn = all_files[_idx]
img = cv2.imread(fn, cv2.IMREAD_COLOR)
msk0 = cv2.imread(fn.replace('/images/', '/masks/'), cv2.IMREAD_UNCHANGED)
msk = msk0[..., np.newaxis]
msk = (msk > 127) * 1
img = preprocess_inputs(img)
img = torch.from_numpy(img.transpose((2, 0, 1))).float()
msk = torch.from_numpy(msk.transpose((2, 0, 1))).long()
sample = {'img': img, 'msk': msk, 'fn': fn}
return sample
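# Runs inference over the validation loader, thresholds the sigmoid output at 0.5,
# and reports the mean per-image Dice score gathered across all DDP ranks.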
def validate(net, data_loader):
dices0 = []
_thr = 0.5
with torch.no_grad():
for i, sample in enumerate(tqdm(data_loader)):
msks = sample["msk"].numpy()
imgs = sample["img"].cuda(non_blocking=True)
            out = net(imgs)
msk_pred = torch.sigmoid(out[:, 0, ...]).cpu().numpy()
for j in range(msks.shape[0]):
dices0.append(dice(msks[j, 0], msk_pred[j] > _thr))
dices0_gathered = all_gather(dices0)
dices0 = np.concatenate(dices0_gathered)
d0 = np.mean(dices0)
print("Val Dice: {}".format(d0))
if wandb.run is not None:
wandb.log({"val_dice":d0})
return d0
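# Evaluates the model on the validation set and, when `save` is set, checkpoints
# it whenever the validation Dice improves on the best score seen so far.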
def evaluate_val(data_val, best_score, model, snapshot_name, current_epoch,save=False):
model = model.eval()
d = validate(model, data_loader=data_val)
if save and d > best_score:
torch.save({
'epoch': current_epoch + 1,
'state_dict': model.state_dict(),
'best_score': d,
}, path.join(models_folder, snapshot_name + '_best'))
best_score = d
print("score: {}\tscore_best: {}".format(d, best_score))
return best_score
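# One training epoch: forward/backward in float16 autocast with a GradScaler,
# gradient clipping at 1.1, and periodic loss/Dice logging to tqdm and wandb.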
def train_epoch(current_epoch, seg_loss, model, optimizer, scheduler, train_data_loader,scaler):
losses = AverageMeter()
dices = AverageMeter()
iterator = tqdm(train_data_loader)
model.train()
for i, sample in enumerate(iterator):
imgs = sample["img"].cuda(non_blocking=True)
msks = sample["msk"].cuda(non_blocking=True)
with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
out = model(imgs)
loss = seg_loss(out, msks)
with torch.no_grad():
_probs = torch.sigmoid(out[:, 0, ...])
dice_sc = 1 - dice_round(_probs, msks[:, 0, ...])
losses.update(loss.item(), imgs.size(0))
        dices.update(dice_sc.item(), imgs.size(0))
iterator.set_description(
"epoch: {}; lr {:.7f}; Loss {loss.val:.4f} ({loss.avg:.4f}); Dice {dice.val:.4f} ({dice.avg:.4f})".format(
                current_epoch, scheduler.get_last_lr()[-1], loss=losses, dice=dices))
if wandb.run is not None and i%20 == 0:
wandb.log(
dict(
epoch=current_epoch,
                    lr=float(scheduler.get_last_lr()[-1]),
loss_avg=losses.avg,
loss_val=losses.val,
dice=dices.avg,
)
)
optimizer.zero_grad()
# loss.backward()
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.1)
scaler.step(optimizer)
scaler.update()
#optimizer.step()
scheduler.step(current_epoch)
print("epoch: {}; lr {:.7f}; Loss {loss.avg:.4f}; Dice {dice.avg:.4f}".format(
        current_epoch, scheduler.get_last_lr()[-1], loss=losses, dice=dices))
from torch.utils.data import DistributedSampler
import torch.distributed as dist
from climax import ClimaXLegacy
from argparse import ArgumentParser
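# Entry point. This script is meant to be launched with torchrun, which provides
# the LOCAL_RANK environment variable read below, e.g. (illustrative command,
# adjust the number of GPUs and the seed to your setup):
#   torchrun --nproc_per_node=4 "trainclimax_loc copy.py" 0
# It expects the xView2 data under ./xview2 and the pretrained ClimaX weights
# in 5.625deg.ckpt.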
if __name__ == '__main__':
t0 = timeit.default_timer()
parser = ArgumentParser()
parser.add_argument('seed',type=int)
parser.add_argument('--val',action='store_true')
parser.add_argument('--resume',action='store_true')
parser.add_argument('--eval-output',action='store_true')
parser.add_argument('--model',default='climax')
args = parser.parse_args()
makedirs(models_folder, exist_ok=True)
seed = args.seed
# vis_dev = sys.argv[2]
# os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
# os.environ["CUDA_VISIBLE_DEVICES"] = vis_dev
local_rank = int(os.environ['LOCAL_RANK'])
cudnn.benchmark = True
batch_size = 15
val_batch_size = 4
dist.init_process_group(backend='nccl')
torch.cuda.set_device(local_rank)
if local_rank == 0:
wandb.init()
snapshot_name = '{}_loc_{}_0'.format(args.model,seed)
train_idxs, val_idxs = train_test_split(np.arange(len(all_files)), test_size=0.1, random_state=seed)
np.random.seed(seed+123)
random.seed(seed+123)
if not os.path.isfile('folds.pth'):
torch.save(
dict(
train_idxs =train_idxs,
val_idxs=val_idxs,
),'folds.pth'
)
else:
folds = torch.load('folds.pth')
train_idxs = folds['train_idxs']
val_idxs = folds['val_idxs']
steps_per_epoch = len(train_idxs) // batch_size
validation_steps = len(val_idxs) // val_batch_size
print('steps_per_epoch', steps_per_epoch, 'validation_steps', validation_steps)
data_train = TrainData(train_idxs)
val_train = ValData(val_idxs)
sampler_train = DistributedSampler(data_train,shuffle=True)
sampler_val = DistributedSampler(val_train,shuffle=False)
train_data_loader = DataLoader(data_train,sampler=sampler_train, batch_size=batch_size, num_workers=6, pin_memory=False, drop_last=True)
val_data_loader = DataLoader(val_train,sampler=sampler_val, batch_size=val_batch_size, num_workers=6, pin_memory=False)
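    # The three RGB channels are exposed to ClimaX as its "default variables"; the
    # backbone is initialized from the pretrained 5.625deg checkpoint and patchifies
    # the 512x512 input with 16x16 patches.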
variables = ['R','G','B']
model = ClimaXLegacy(img_size=[512,512],patch_size=16,default_vars=variables,pretrained='5.625deg.ckpt').cuda()
params = model.parameters()
if args.resume:
snap_to_load = '{}_loc_{}_0_best'.format(args.model,seed)
print("=> loading checkpoint '{}'".format(snap_to_load))
checkpoint = torch.load(path.join(models_folder, snap_to_load), map_location='cpu')
loaded_dict = checkpoint['state_dict']
if 'module.' in list(loaded_dict.keys())[0]:
loaded_dict = {k[7:]:v for k,v in loaded_dict.items()}
model.load_state_dict(loaded_dict,strict=True)
print("loaded checkpoint '{}' (epoch {}, best_score {})"
.format(snap_to_load, checkpoint['epoch'], checkpoint['best_score']))
del loaded_dict
del checkpoint
gc.collect()
torch.cuda.empty_cache()
    # Convert BatchNorm layers before wrapping in DDP so the wrapper registers the synced modules.
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
optimizer = AdamW(params, lr=0.00015*0.1, weight_decay=1e-6)
scaler = torch.cuda.amp.GradScaler()
#model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
#scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[15, 29, 43, 53, 65, 80, 90, 100, 110, 130, 150, 170, 180, 190], gamma=0.5)
scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[ 140, 180, ], gamma=0.1)
seg_loss = ComboLoss({'dice': 1.0, 'focal': 10.0}, per_image=False).cuda()
best_score = 0
_cnt = -1
torch.cuda.empty_cache()
if args.val:
best_score = evaluate_val(val_data_loader, best_score, model, snapshot_name, -1)
elif args.eval_output:
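        # Offline prediction mode: each rank handles a shard of the validation images,
        # averages the prediction with a 180-degree-rotated copy (simple TTA), and
        # writes the localization probability map as an 8-bit PNG.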
pred_folder = f'pred_loc_val_{args.model}'
Path(pred_folder).mkdir(exist_ok=True)
with torch.no_grad():
indices = list([x for x in val_idxs if x % dist.get_world_size() == dist.get_rank()])
for idx in tqdm(indices):
fn = all_files[idx]
f = fn.split('/')[-1]
img = cv2.imread(fn, cv2.IMREAD_COLOR)
img = preprocess_inputs(img)
inp = []
inp.append(img)
inp.append(img[::-1, ::-1, ...])
inp = np.asarray(inp, dtype='float')
inp = torch.from_numpy(inp.transpose((0, 3, 1, 2))).float()
                inp = inp.cuda()
pred = []
msk = model(inp)
msk = torch.sigmoid(msk)
msk = msk.cpu().numpy()
pred.append(msk[0, ...])
pred.append(msk[1, :, ::-1, ::-1])
pred_full = np.asarray(pred).mean(axis=0)
msk = pred_full * 255
msk = msk.astype('uint8').transpose(1, 2, 0)
cv2.imwrite(path.join(pred_folder, '{0}.png'.format(f.replace('.png', '_part1.png'))), msk[..., 0], [cv2.IMWRITE_PNG_COMPRESSION, 9])
else:
# train
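        # Main loop: 200 epochs, with validation (and best-checkpoint saving)
        # every second epoch.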
for epoch in range(200):
            sampler_train.set_epoch(epoch)  # reshuffle the DistributedSampler shards each epoch
            train_epoch(epoch, seg_loss, model, optimizer, scheduler, train_data_loader, scaler)
if epoch % 2 == 0:
_cnt += 1
torch.cuda.empty_cache()
best_score = evaluate_val(val_data_loader, best_score, model, snapshot_name, epoch,True)
elapsed = timeit.default_timer() - t0
print('Time: {:.3f} min'.format(elapsed / 60))