import os

# Cap BLAS/OpenMP thread pools so the DataLoader workers do not oversubscribe CPU cores.
os.environ["MKL_NUM_THREADS"] = "2"
os.environ["NUMEXPR_NUM_THREADS"] = "2"
os.environ["OMP_NUM_THREADS"] = "2"

from os import path, makedirs, listdir
import sys
import numpy as np

np.random.seed(1)
import random

random.seed(1)

import torch
from torch import nn
from torch.backends import cudnn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.optim.lr_scheduler as lr_scheduler

from torch import amp

from adamw import AdamW
from losses import dice_round, ComboLoss

import pandas as pd
from tqdm import tqdm
import timeit
import cv2

from zoo.models import SeResNext50_Unet_Double

from imgaug import augmenters as iaa

from utils import *

from skimage.morphology import square, dilation

from sklearn.model_selection import train_test_split

from sklearn.metrics import accuracy_score
from torch.nn.parallel import DistributedDataParallel
import gc
from utils import all_gather

# Disable OpenCV threading/OpenCL to avoid contention with the DataLoader workers.
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)

train_dirs = ['train', 'tier3']
ROOT = './xview2'
models_folder = 'weights'

input_shape = (512, 512)


# Collect all pre-disaster images; the matching post-disaster image and the masks are
# derived from each file name inside the datasets below.
all_files = []
for d in train_dirs:
    for f in sorted(listdir(path.join(ROOT, d, 'images'))):
        if '_pre_disaster.png' in f:
            all_files.append(path.join(ROOT, d, 'images', f))
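

# TrainData yields one heavily augmented sample per pre-disaster image:
#   'img'     - pre- and post-disaster images concatenated channel-wise, float tensor (6, 512, 512)
#   'msk'     - 5-channel target [background, damage levels 1-4], long tensor (5, 512, 512)
#   'lbl_msk' - per-pixel class index (argmax over 'msk'), used as the cross-entropy target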
class TrainData(Dataset):
    def __init__(self, train_idxs):
        super().__init__()
        self.train_idxs = train_idxs
        self.elastic = iaa.ElasticTransformation(alpha=(0.25, 1.2), sigma=0.2)

    def __len__(self):
        return len(self.train_idxs)

    def __getitem__(self, idx):
        _idx = self.train_idxs[idx]

        fn = all_files[_idx]

        img = cv2.imread(fn, cv2.IMREAD_COLOR)
        img2 = cv2.imread(fn.replace('_pre_disaster', '_post_disaster'), cv2.IMREAD_COLOR)

        msk0 = cv2.imread(fn.replace('/images/', '/masks/'), cv2.IMREAD_UNCHANGED)
        lbl_msk1 = cv2.imread(fn.replace('/images/', '/masks/').replace('_pre_disaster', '_post_disaster'), cv2.IMREAD_UNCHANGED)
        # Split the post-disaster label map into one binary mask per damage level.
        msk1 = np.zeros_like(lbl_msk1)
        msk2 = np.zeros_like(lbl_msk1)
        msk3 = np.zeros_like(lbl_msk1)
        msk4 = np.zeros_like(lbl_msk1)
        msk1[lbl_msk1 == 1] = 255
        msk2[lbl_msk1 == 2] = 255
        msk3[lbl_msk1 == 3] = 255
        msk4[lbl_msk1 == 4] = 255

        # Random vertical flip.
        if random.random() > 0.5:
            img = img[::-1, ...]
            img2 = img2[::-1, ...]
            msk0 = msk0[::-1, ...]
            msk1 = msk1[::-1, ...]
            msk2 = msk2[::-1, ...]
            msk3 = msk3[::-1, ...]
            msk4 = msk4[::-1, ...]

        # Random 90-degree rotation.
        if random.random() > 0.05:
            rot = random.randrange(4)
            if rot > 0:
                img = np.rot90(img, k=rot)
                img2 = np.rot90(img2, k=rot)
                msk0 = np.rot90(msk0, k=rot)
                msk1 = np.rot90(msk1, k=rot)
                msk2 = np.rot90(msk2, k=rot)
                msk3 = np.rot90(msk3, k=rot)
                msk4 = np.rot90(msk4, k=rot)

        # Random translation.
        if random.random() > 0.8:
            shift_pnt = (random.randint(-320, 320), random.randint(-320, 320))
            img = shift_image(img, shift_pnt)
            img2 = shift_image(img2, shift_pnt)
            msk0 = shift_image(msk0, shift_pnt)
            msk1 = shift_image(msk1, shift_pnt)
            msk2 = shift_image(msk2, shift_pnt)
            msk3 = shift_image(msk3, shift_pnt)
            msk4 = shift_image(msk4, shift_pnt)

        # Random rotation/scale around a jittered centre.
        if random.random() > 0.2:
            rot_pnt = (img.shape[0] // 2 + random.randint(-320, 320), img.shape[1] // 2 + random.randint(-320, 320))
            scale = 0.9 + random.random() * 0.2
            angle = random.randint(0, 20) - 10
            if (angle != 0) or (scale != 1):
                img = rotate_image(img, angle, scale, rot_pnt)
                img2 = rotate_image(img2, angle, scale, rot_pnt)
                msk0 = rotate_image(msk0, angle, scale, rot_pnt)
                msk1 = rotate_image(msk1, angle, scale, rot_pnt)
                msk2 = rotate_image(msk2, angle, scale, rot_pnt)
                msk3 = rotate_image(msk3, angle, scale, rot_pnt)
                msk4 = rotate_image(msk4, angle, scale, rot_pnt)

        # Randomly sized crop; among a few candidate positions keep the one covering the
        # most damaged-building pixels (classes 2 and 3 weighted highest).
        crop_size = input_shape[0]
        if random.random() > 0.1:
            crop_size = random.randint(int(input_shape[0] / 1.15), int(input_shape[0] / 0.85))

        bst_x0 = random.randint(0, img.shape[1] - crop_size)
        bst_y0 = random.randint(0, img.shape[0] - crop_size)
        bst_sc = -1
        try_cnt = random.randint(1, 10)
        for i in range(try_cnt):
            x0 = random.randint(0, img.shape[1] - crop_size)
            y0 = random.randint(0, img.shape[0] - crop_size)
            _sc = msk2[y0:y0+crop_size, x0:x0+crop_size].sum() * 5 + msk3[y0:y0+crop_size, x0:x0+crop_size].sum() * 5 + msk4[y0:y0+crop_size, x0:x0+crop_size].sum() * 2 + msk1[y0:y0+crop_size, x0:x0+crop_size].sum()
            if _sc > bst_sc:
                bst_sc = _sc
                bst_x0 = x0
                bst_y0 = y0
        x0 = bst_x0
        y0 = bst_y0
        img = img[y0:y0+crop_size, x0:x0+crop_size, :]
        img2 = img2[y0:y0+crop_size, x0:x0+crop_size, :]
        msk0 = msk0[y0:y0+crop_size, x0:x0+crop_size]
        msk1 = msk1[y0:y0+crop_size, x0:x0+crop_size]
        msk2 = msk2[y0:y0+crop_size, x0:x0+crop_size]
        msk3 = msk3[y0:y0+crop_size, x0:x0+crop_size]
        msk4 = msk4[y0:y0+crop_size, x0:x0+crop_size]

        if crop_size != input_shape[0]:
            img = cv2.resize(img, input_shape, interpolation=cv2.INTER_LINEAR)
            img2 = cv2.resize(img2, input_shape, interpolation=cv2.INTER_LINEAR)
            msk0 = cv2.resize(msk0, input_shape, interpolation=cv2.INTER_LINEAR)
            msk1 = cv2.resize(msk1, input_shape, interpolation=cv2.INTER_LINEAR)
            msk2 = cv2.resize(msk2, input_shape, interpolation=cv2.INTER_LINEAR)
            msk3 = cv2.resize(msk3, input_shape, interpolation=cv2.INTER_LINEAR)
            msk4 = cv2.resize(msk4, input_shape, interpolation=cv2.INTER_LINEAR)

        # Low-probability photometric augmentations, applied independently to the
        # pre- and post-disaster images.
        if random.random() > 0.96:
            img = shift_channels(img, random.randint(-5, 5), random.randint(-5, 5), random.randint(-5, 5))
        elif random.random() > 0.96:
            img2 = shift_channels(img2, random.randint(-5, 5), random.randint(-5, 5), random.randint(-5, 5))

        if random.random() > 0.96:
            img = change_hsv(img, random.randint(-5, 5), random.randint(-5, 5), random.randint(-5, 5))
        elif random.random() > 0.96:
            img2 = change_hsv(img2, random.randint(-5, 5), random.randint(-5, 5), random.randint(-5, 5))

        if random.random() > 0.9:
            if random.random() > 0.96:
                img = clahe(img)
            elif random.random() > 0.96:
                img = gauss_noise(img)
            elif random.random() > 0.96:
                img = cv2.blur(img, (3, 3))
        elif random.random() > 0.9:
            if random.random() > 0.96:
                img = saturation(img, 0.9 + random.random() * 0.2)
            elif random.random() > 0.96:
                img = brightness(img, 0.9 + random.random() * 0.2)
            elif random.random() > 0.96:
                img = contrast(img, 0.9 + random.random() * 0.2)

        if random.random() > 0.9:
            if random.random() > 0.96:
                img2 = clahe(img2)
            elif random.random() > 0.96:
                img2 = gauss_noise(img2)
            elif random.random() > 0.96:
                img2 = cv2.blur(img2, (3, 3))
        elif random.random() > 0.9:
            if random.random() > 0.96:
                img2 = saturation(img2, 0.9 + random.random() * 0.2)
            elif random.random() > 0.96:
                img2 = brightness(img2, 0.9 + random.random() * 0.2)
            elif random.random() > 0.96:
                img2 = contrast(img2, 0.9 + random.random() * 0.2)

        if random.random() > 0.96:
            el_det = self.elastic.to_deterministic()
            img = el_det.augment_image(img)

        if random.random() > 0.96:
            el_det = self.elastic.to_deterministic()
            img2 = el_det.augment_image(img2)

        msk0 = msk0[..., np.newaxis]
        msk1 = msk1[..., np.newaxis]
        msk2 = msk2[..., np.newaxis]
        msk3 = msk3[..., np.newaxis]
        msk4 = msk4[..., np.newaxis]

        msk = np.concatenate([msk0, msk1, msk2, msk3, msk4], axis=2)
        msk = (msk > 127)

        # Channel 0 becomes the background class; the damage channels are dilated and then
        # made mutually exclusive with a fixed priority (2 > 3 > 4 > 1).
        msk[..., 0] = True
        msk[..., 1] = dilation(msk[..., 1], square(5))
        msk[..., 2] = dilation(msk[..., 2], square(5))
        msk[..., 3] = dilation(msk[..., 3], square(5))
        msk[..., 4] = dilation(msk[..., 4], square(5))
        msk[..., 1][msk[..., 2:].max(axis=2)] = False
        msk[..., 3][msk[..., 2]] = False
        msk[..., 4][msk[..., 2]] = False
        msk[..., 4][msk[..., 3]] = False
        msk[..., 0][msk[..., 1:].max(axis=2)] = False
        msk = msk * 1

        lbl_msk = msk.argmax(axis=2)

        img = np.concatenate([img, img2], axis=2)
        img = preprocess_inputs(img)

        img = torch.from_numpy(img.transpose((2, 0, 1))).float()
        msk = torch.from_numpy(msk.transpose((2, 0, 1))).long()

        sample = {'img': img, 'msk': msk, 'lbl_msk': lbl_msk, 'fn': fn}
        return sample
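

# ValData returns un-augmented full-size samples. In addition to the targets it loads
# 'msk_loc', a precomputed localization prediction from loc_folder thresholded at 0.3
# (of 255), which validate() uses as the building mask. Unlike TrainData, channel 0 of
# 'msk' keeps the raw pre-disaster building footprint.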
class ValData(Dataset):
    def __init__(self, image_idxs, loc_folder):
        super().__init__()
        self.image_idxs = image_idxs
        self.loc_folder = loc_folder

    def __len__(self):
        return len(self.image_idxs)

    def __getitem__(self, idx):
        _idx = self.image_idxs[idx]

        fn = all_files[_idx]

        img = cv2.imread(fn, cv2.IMREAD_COLOR)
        img2 = cv2.imread(fn.replace('_pre_disaster', '_post_disaster'), cv2.IMREAD_COLOR)

        # Precomputed localization prediction for this image, thresholded into a binary mask.
        msk_loc = cv2.imread(path.join(self.loc_folder, '{0}.png'.format(fn.split('/')[-1].replace('.png', '_part1.png'))), cv2.IMREAD_UNCHANGED) > (0.3 * 255)

        msk0 = cv2.imread(fn.replace('/images/', '/masks/'), cv2.IMREAD_UNCHANGED)
        lbl_msk1 = cv2.imread(fn.replace('/images/', '/masks/').replace('_pre_disaster', '_post_disaster'), cv2.IMREAD_UNCHANGED)
        msk1 = np.zeros_like(lbl_msk1)
        msk2 = np.zeros_like(lbl_msk1)
        msk3 = np.zeros_like(lbl_msk1)
        msk4 = np.zeros_like(lbl_msk1)
        msk1[lbl_msk1 == 1] = 255
        msk2[lbl_msk1 == 2] = 255
        msk3[lbl_msk1 == 3] = 255
        msk4[lbl_msk1 == 4] = 255

        msk0 = msk0[..., np.newaxis]
        msk1 = msk1[..., np.newaxis]
        msk2 = msk2[..., np.newaxis]
        msk3 = msk3[..., np.newaxis]
        msk4 = msk4[..., np.newaxis]

        msk = np.concatenate([msk0, msk1, msk2, msk3, msk4], axis=2)
        msk = (msk > 127)

        msk = msk * 1

        # Per-pixel damage label (0..3) over the damage channels only.
        lbl_msk = msk[..., 1:].argmax(axis=2)

        img = np.concatenate([img, img2], axis=2)
        img = preprocess_inputs(img)

        img = torch.from_numpy(img.transpose((2, 0, 1))).float()
        msk = torch.from_numpy(msk.transpose((2, 0, 1))).long()

        sample = {'img': img, 'msk': msk, 'lbl_msk': lbl_msk, 'fn': fn, 'msk_loc': msk_loc}
        return sample
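

# validate() computes an xView2-style score: Dice of the precomputed localization mask
# against the building footprint, plus per-class F1 for the four damage levels evaluated
# inside the footprint. TP/FP/FN counts are gathered from all DDP ranks before the final
# score = 0.3 * localization_dice + 0.7 * harmonic_mean(damage F1s) is printed.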
def validate(net, data_loader):
    # Accumulated per class: indices 0..3 are the damage classes, index 4 is localization.
    tp = np.zeros((5,))
    fp = np.zeros((5,))
    fn = np.zeros((5,))

    _thr = 0.3

    with torch.no_grad():
        for i, sample in enumerate(tqdm(data_loader)):
            msks = sample["msk"].numpy()
            lbl_msk = sample["lbl_msk"].numpy()
            imgs = sample["img"].cuda(non_blocking=True)
            msk_loc = sample["msk_loc"].numpy() * 1
            out = local_forward(imgs, (256, 256), net)

            msk_pred = msk_loc
            msk_damage_pred = torch.softmax(out, dim=1).cpu().numpy()[:, 1:, ...]

            for j in range(msks.shape[0]):
                # Localization: compare the precomputed mask against the building footprint.
                tp[4] += np.logical_and(msks[j, 0] > 0, msk_pred[j] > 0).sum()
                fn[4] += np.logical_and(msks[j, 0] > 0, msk_pred[j] < 1).sum()
                fp[4] += np.logical_and(msks[j, 0] < 1, msk_pred[j] > 0).sum()

                # Damage: evaluate only inside the ground-truth footprint.
                targ = lbl_msk[j][msks[j, 0] > 0]
                pred = msk_damage_pred[j].argmax(axis=0)
                pred = pred * (msk_pred[j] > _thr)
                pred = pred[msks[j, 0] > 0]
                for c in range(4):
                    tp[c] += np.logical_and(pred == c, targ == c).sum()
                    fn[c] += np.logical_and(pred != c, targ == c).sum()
                    fp[c] += np.logical_and(pred == c, targ != c).sum()

    # Reduce the confusion counts across all DDP ranks before computing the scores.
    all_gathered = all_gather(dict(tp=tp, fp=fp, fn=fn))
    tp = np.zeros((5,))
    fp = np.zeros((5,))
    fn = np.zeros((5,))
    for d in all_gathered:
        tp += d['tp']
        fp += d['fp']
        fn += d['fn']

    d0 = 2 * tp[4] / (2 * tp[4] + fp[4] + fn[4])

    f1_sc = np.zeros((4,))
    for c in range(4):
        f1_sc[c] = 2 * tp[c] / (2 * tp[c] + fp[c] + fn[c])

    # Harmonic mean of the four damage F1 scores.
    f1 = 4 / np.sum(1.0 / (f1_sc + 1e-6))

    sc = 0.3 * d0 + 0.7 * f1
    print("Val Score: {}, Dice: {}, F1: {}, F1_0: {}, F1_1: {}, F1_2: {}, F1_3: {}".format(sc, d0, f1, f1_sc[0], f1_sc[1], f1_sc[2], f1_sc[3]))
    return sc
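

# evaluate_val() runs validation and, when save=True, writes a checkpoint to models_folder
# whenever the score improves on best_score.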
def evaluate_val(data_val, best_score, model, snapshot_name, current_epoch, save=False):
    model = model.eval()
    d = validate(model, data_loader=data_val)

    if save and d > best_score:
        torch.save({
            'epoch': current_epoch + 1,
            'state_dict': model.state_dict(),
            'best_score': d,
        }, path.join(models_folder, snapshot_name + '_best'))
        best_score = d

    print("score: {}\tscore_best: {}".format(d, best_score))
    return best_score
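

# local_forward() splits the input into non-overlapping img_size tiles, runs the model on
# each tile and stitches the 5-channel logits back together. Height and width must be exact
# multiples of the tile size; for the full-size 1024x1024 xBD images, the (256, 256) tiles
# used in validate() give a 4x4 grid of forward passes.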
def local_forward(x, img_size, model):
    assert x.shape[2] % img_size[0] == 0 and x.shape[3] % img_size[1] == 0
    nh = x.shape[2] // img_size[0]
    nw = x.shape[3] // img_size[1]
    output = torch.zeros(x.shape[0], 5, x.shape[2], x.shape[3]).to(x)
    for i in range(nh):
        for j in range(nw):
            local_y = model(x[..., i * img_size[0]:(i + 1) * img_size[0], j * img_size[1]:(j + 1) * img_size[1]])
            output[..., i * img_size[0]:(i + 1) * img_size[0], j * img_size[1]:(j + 1) * img_size[1]] = local_y
    return output
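

# One training epoch: forward pass in float16 autocast, ComboLoss (dice + focal) applied to
# each of the 5 output channels with weights 0.1/0.1/0.3/0.3/0.2, plus a cross-entropy term
# over the per-pixel class map scaled by 11. Gradients are scaled, unscaled for norm clipping
# at 1.1, then applied; a Dice score derived from the background channel is tracked for
# monitoring only.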
def train_epoch(current_epoch, seg_loss, ce_loss, model, optimizer, scheduler, train_data_loader, scaler):
    losses = AverageMeter()
    losses1 = AverageMeter()

    dices = AverageMeter()

    iterator = tqdm(train_data_loader)
    model.train()
    for i, sample in enumerate(iterator):
        imgs = sample["img"].cuda(non_blocking=True)
        msks = sample["msk"].cuda(non_blocking=True)
        lbl_msk = sample["lbl_msk"].cuda(non_blocking=True)
        with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
            out = model(imgs)

            # Per-channel segmentation losses plus a cross-entropy term over the class map.
            loss0 = seg_loss(out[:, 0, ...], msks[:, 0, ...])
            loss1 = seg_loss(out[:, 1, ...], msks[:, 1, ...])
            loss2 = seg_loss(out[:, 2, ...], msks[:, 2, ...])
            loss3 = seg_loss(out[:, 3, ...], msks[:, 3, ...])
            loss4 = seg_loss(out[:, 4, ...], msks[:, 4, ...])

            loss5 = ce_loss(out, lbl_msk)

            loss = 0.1 * loss0 + 0.1 * loss1 + 0.3 * loss2 + 0.3 * loss3 + 0.2 * loss4 + loss5 * 11

        with torch.no_grad():
            _probs = 1 - torch.sigmoid(out[:, 0, ...])
            dice_sc = 1 - dice_round(_probs, 1 - msks[:, 0, ...])

        losses.update(loss.item(), imgs.size(0))
        losses1.update(loss5.item(), imgs.size(0))

        dices.update(dice_sc, imgs.size(0))

        iterator.set_description(
            "epoch: {}; lr {:.7f}; Loss {loss.val:.4f} ({loss.avg:.4f}); cce_loss {loss1.val:.4f} ({loss1.avg:.4f}); Dice {dice.val:.4f} ({dice.avg:.4f})".format(
                current_epoch, scheduler.get_last_lr()[-1], loss=losses, loss1=losses1, dice=dices))

        optimizer.zero_grad()

        # Scaled backward pass; unscale before clipping the gradient norm.
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.1)
        scaler.step(optimizer)
        scaler.update()
        torch.cuda.empty_cache()

    scheduler.step(current_epoch)

    print("epoch: {}; lr {:.7f}; Loss {loss.avg:.4f}; CCE_loss {loss1.avg:.4f}; Dice {dice.avg:.4f}".format(
        current_epoch, scheduler.get_last_lr()[-1], loss=losses, loss1=losses1, dice=dices))
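

# Single-node distributed entry point: one process per GPU, with LOCAL_RANK provided by the
# launcher. A sketch of a typical launch (the script name is a placeholder, not from this
# repo; the positional argument is the seed):
#   torchrun --nproc_per_node=4 train_cls.py 0 --model res50
# Unless --resume is given, the model is warm-started from the corresponding localization
# checkpoint ('<model>_loc_<seed>_0_best') in the weights folder.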
from torch.utils.data import DistributedSampler
import torch.distributed as dist
from argparse import ArgumentParser

if __name__ == '__main__':
    t0 = timeit.default_timer()
    parser = ArgumentParser()
    parser.add_argument('seed', type=int)
    parser.add_argument('--val', action='store_true')
    parser.add_argument('--resume', action='store_true')
    parser.add_argument('--model', default='res50')
    args = parser.parse_args()
    makedirs(models_folder, exist_ok=True)

    seed = args.seed

    # LOCAL_RANK is set by the distributed launcher (one process per GPU).
    local_rank = int(os.environ['LOCAL_RANK'])
    cudnn.benchmark = True

    batch_size = 8
    val_batch_size = 4
    dist.init_process_group(backend='nccl')
    torch.cuda.set_device(local_rank)
    snapshot_name = '{}_cls_cce_{}_0'.format(args.model, seed)

    # For every image record which damage classes (1..4) are present; this drives the
    # oversampling of rare classes below.
    file_classes = []
    for fn in tqdm(all_files):
        fl = np.zeros((4,), dtype=bool)
        msk1 = cv2.imread(fn.replace('/images/', '/masks/').replace('_pre_disaster', '_post_disaster'), cv2.IMREAD_UNCHANGED)
        for c in range(1, 5):
            fl[c - 1] = c in msk1
        file_classes.append(fl)
    file_classes = np.asarray(file_classes)

    train_idxs0, val_idxs = train_test_split(np.arange(len(all_files)), test_size=0.1, random_state=seed)

    np.random.seed(seed + 1234)
    random.seed(seed + 1234)
    # Persist the split so that repeated runs use the same folds.
    if not os.path.isfile('folds.pth'):
        torch.save(
            dict(
                train_idxs=train_idxs0,
                val_idxs=val_idxs,
            ), 'folds.pth'
        )
    else:
        folds = torch.load('folds.pth')
        train_idxs0 = folds['train_idxs']
        val_idxs = folds['val_idxs']

    # Oversample images containing damaged buildings (and, even more, classes 2-3).
    train_idxs = []
    for i in train_idxs0:
        train_idxs.append(i)
        if file_classes[i, 1:].max():
            train_idxs.append(i)
        if file_classes[i, 1:3].max():
            train_idxs.append(i)
    train_idxs = np.asarray(train_idxs)

    steps_per_epoch = len(train_idxs) // batch_size
    validation_steps = len(val_idxs) // val_batch_size

    print('steps_per_epoch', steps_per_epoch, 'validation_steps', validation_steps)

    data_train = TrainData(train_idxs)
    val_train = ValData(val_idxs, loc_folder=f'pred_loc_val_{args.model}')
    sampler_train = DistributedSampler(data_train, shuffle=True)
    sampler_val = DistributedSampler(val_train, shuffle=False)
    train_data_loader = DataLoader(data_train, sampler=sampler_train, batch_size=batch_size, num_workers=6, pin_memory=False, drop_last=True)
    val_data_loader = DataLoader(val_train, sampler=sampler_val, batch_size=val_batch_size, num_workers=6, pin_memory=False)

    model = SeResNext50_Unet_Double().cuda()

    params = model.parameters()

    optimizer = AdamW(params, lr=0.00001, weight_decay=1e-6)

    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[5, 11, 17, 23, 29, 33, 47, 50, 60, 70, 90, 110, 130, 150, 170, 180, 190], gamma=0.5)

    # Warm-start from the localization checkpoint, or from this run's own best snapshot when
    # resuming; layers whose shapes do not match are skipped.
    if args.resume:
        snap_to_load = snapshot_name + '_best'
    else:
        snap_to_load = '{}_loc_{}_0_best'.format(args.model, seed)
    print("=> loading checkpoint '{}'".format(snap_to_load))
    checkpoint = torch.load(path.join(models_folder, snap_to_load), map_location='cpu')
    loaded_dict = checkpoint['state_dict']
    if 'module.' in list(loaded_dict.keys())[0]:
        loaded_dict = {k[7:]: v for k, v in loaded_dict.items()}
    sd = model.state_dict()
    for k in model.state_dict():
        if k in loaded_dict and sd[k].size() == loaded_dict[k].size():
            sd[k] = loaded_dict[k]
        else:
            print("SKIPPED:", k)
    loaded_dict = sd
    model.load_state_dict(loaded_dict)
    print("loaded checkpoint '{}' (epoch {}, best_score {})"
          .format(snap_to_load, checkpoint['epoch'], checkpoint['best_score']))
    del loaded_dict
    del sd
    del checkpoint
    gc.collect()
    torch.cuda.empty_cache()

    model = DistributedDataParallel(model)
    scaler = torch.cuda.amp.GradScaler()

    seg_loss = ComboLoss({'dice': 0.5, 'focal': 2.0}, per_image=False).cuda()
    ce_loss = nn.CrossEntropyLoss().cuda()

    best_score = 0
    torch.cuda.empty_cache()
    if args.val:
        best_score = evaluate_val(val_data_loader, best_score, model, snapshot_name, -1)
    else:
        for epoch in range(10):
            # Reshuffle the DistributedSampler differently each epoch.
            sampler_train.set_epoch(epoch)
            train_epoch(epoch, seg_loss, ce_loss, model, optimizer, scheduler, train_data_loader, scaler)
            if epoch % 2 == 0:
                torch.cuda.empty_cache()
                best_score = evaluate_val(val_data_loader, best_score, model, snapshot_name, epoch, True)

    elapsed = timeit.default_timer() - t0
    print('Time: {:.3f} min'.format(elapsed / 60))