import os
import time
import json
import datetime

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.utils.data import DataLoader
from torchvision import transforms

from dataloaders.train_datasets import DAVIS2017_Train, YOUTUBEVOS_Train, StaticTrain, TEST
import dataloaders.video_transforms as tr

from utils.meters import AverageMeter
from utils.image import label2colormap, masked_image, save_image
from utils.checkpoint import load_network_and_optimizer, load_network, save_network
from utils.learning import adjust_learning_rate, get_trainable_params
from utils.metric import pytorch_iou
from utils.ema import ExponentialMovingAverage, get_param_buffer_for_ema

from networks.models import build_vos_model
from networks.engines import build_engine

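# Trainer drives one training stage: it builds the VOS model and engine,
# optionally wraps them in DistributedDataParallel with SyncBatchNorm, sets
# up the optimizer, AMP scaler, and an EMA copy of the weights, handles
# checkpoint resume, and runs the sequential training loop.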
class Trainer(object):
    def __init__(self, rank, cfg, enable_amp=True):
        self.gpu = rank + cfg.DIST_START_GPU
        self.gpu_num = cfg.TRAIN_GPUS
        self.rank = rank
        self.cfg = cfg

        self.print_log("Exp {}:".format(cfg.EXP_NAME))
        self.print_log(json.dumps(cfg.__dict__, indent=4, sort_keys=True))

        print("Use GPU {} for training VOS.".format(self.gpu))
        torch.cuda.set_device(self.gpu)
        # cudnn autotuning pays off only when input shapes stay fixed, so
        # enable it for square crops and skip it for Swin encoders.
        torch.backends.cudnn.benchmark = (
            cfg.DATA_RANDOMCROP[0] == cfg.DATA_RANDOMCROP[1]
            and 'swin' not in cfg.MODEL_ENCODER)

        self.print_log('Build VOS model.')

        self.model = build_vos_model(cfg.MODEL_VOS, cfg).cuda(self.gpu)
        self.model_encoder = self.model.encoder
        self.engine = build_engine(
            cfg.MODEL_ENGINE,
            'train',
            aot_model=self.model,
            gpu_id=self.gpu,
            long_term_mem_gap=cfg.TRAIN_LONG_TERM_MEM_GAP)

        if cfg.MODEL_FREEZE_BACKBONE:
            for param in self.model_encoder.parameters():
                param.requires_grad = False

        if cfg.DIST_ENABLE:
            dist.init_process_group(backend=cfg.DIST_BACKEND,
                                    init_method=cfg.DIST_URL,
                                    world_size=cfg.TRAIN_GPUS,
                                    rank=rank,
                                    timeout=datetime.timedelta(seconds=300))

            self.model.encoder = nn.SyncBatchNorm.convert_sync_batchnorm(
                self.model.encoder).cuda(self.gpu)

            self.dist_engine = torch.nn.parallel.DistributedDataParallel(
                self.engine,
                device_ids=[self.gpu],
                output_device=self.gpu,
                find_unused_parameters=True,
                broadcast_buffers=False)
        else:
            self.dist_engine = self.engine

        self.use_frozen_bn = False
        if 'swin' in cfg.MODEL_ENCODER:
            self.print_log('Use LN in Encoder!')
        elif not cfg.MODEL_FREEZE_BN:
            if cfg.DIST_ENABLE:
                self.print_log('Use Sync BN in Encoder!')
            else:
                self.print_log('Use BN in Encoder!')
        else:
            self.use_frozen_bn = True
            self.print_log('Use Frozen BN in Encoder!')

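        # EMA decay is tied to schedule length: decay = 1 - 1/(total_steps *
        # TRAIN_EMA_RATIO). E.g. 100,000 steps with a ratio of 0.1 gives
        # decay = 1 - 1/10,000 = 0.9999, so the EMA effectively averages the
        # weights over roughly the last 10,000 updates.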
        if self.rank == 0:
            try:
                total_steps = float(cfg.TRAIN_TOTAL_STEPS)
                ema_decay = 1. - 1. / (total_steps * cfg.TRAIN_EMA_RATIO)
                self.ema_params = get_param_buffer_for_ema(
                    self.model, update_buffer=(not cfg.MODEL_FREEZE_BN))
                self.ema = ExponentialMovingAverage(self.ema_params,
                                                    decay=ema_decay)
                self.ema_dir = cfg.DIR_EMA_CKPT
            except Exception as inst:
                self.print_log(inst)
                self.print_log('Error: failed to create EMA model!')

        self.print_log('Build optimizer.')

        trainable_params = get_trainable_params(
            model=self.dist_engine,
            base_lr=cfg.TRAIN_LR,
            use_frozen_bn=self.use_frozen_bn,
            weight_decay=cfg.TRAIN_WEIGHT_DECAY,
            exclusive_wd_dict=cfg.TRAIN_WEIGHT_DECAY_EXCLUSIVE,
            no_wd_keys=cfg.TRAIN_WEIGHT_DECAY_EXEMPTION)

        if cfg.TRAIN_OPT == 'sgd':
            self.optimizer = optim.SGD(trainable_params,
                                       lr=cfg.TRAIN_LR,
                                       momentum=cfg.TRAIN_SGD_MOMENTUM,
                                       nesterov=True)
        else:
            self.optimizer = optim.AdamW(trainable_params,
                                         lr=cfg.TRAIN_LR,
                                         weight_decay=cfg.TRAIN_WEIGHT_DECAY)

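        # With AMP enabled, forward/backward run under autocast and the
        # GradScaler scales the loss so fp16 gradients do not underflow; the
        # training loop unscales them again before gradient clipping.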
        self.enable_amp = enable_amp
        if enable_amp:
            self.scaler = torch.cuda.amp.GradScaler()
        else:
            self.scaler = None

        self.prepare_dataset()
        self.process_pretrained_model()

        if cfg.TRAIN_TBLOG and self.rank == 0:
            from tensorboardX import SummaryWriter
            self.tblogger = SummaryWriter(cfg.DIR_TB_LOG)

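    # Checkpoint policy: with TRAIN_AUTO_RESUME, training resumes from the
    # newest checkpoint in DIR_CKPT, falling back to the mirror under
    # ./backup/<EXP_NAME>/<STAGE_NAME> if loading fails; otherwise an
    # optional pretrained full model or backbone is loaded.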
    def process_pretrained_model(self):
        cfg = self.cfg

        self.step = cfg.TRAIN_START_STEP
        self.epoch = 0

        if cfg.TRAIN_AUTO_RESUME:
            ckpts = os.listdir(cfg.DIR_CKPT)
            if len(ckpts) > 0:
                ckpts = list(
                    map(lambda x: int(x.split('_')[-1].split('.')[0]), ckpts))
                ckpt = np.sort(ckpts)[-1]
                cfg.TRAIN_RESUME = True
                cfg.TRAIN_RESUME_CKPT = ckpt
                cfg.TRAIN_RESUME_STEP = ckpt
            else:
                cfg.TRAIN_RESUME = False

        if cfg.TRAIN_RESUME:
            if self.rank == 0:
                try:
                    try:
                        ema_ckpt_dir = os.path.join(
                            self.ema_dir,
                            'save_step_%s.pth' % (cfg.TRAIN_RESUME_CKPT))
                        ema_model, removed_dict = load_network(
                            self.model, ema_ckpt_dir, self.gpu)
                    except Exception as inst:
                        self.print_log(inst)
                        self.print_log('Try to use backup EMA checkpoint.')
                        DIR_RESULT = './backup/{}/{}'.format(
                            cfg.EXP_NAME, cfg.STAGE_NAME)
                        DIR_EMA_CKPT = os.path.join(DIR_RESULT, 'ema_ckpt')
                        ema_ckpt_dir = os.path.join(
                            DIR_EMA_CKPT,
                            'save_step_%s.pth' % (cfg.TRAIN_RESUME_CKPT))
                        ema_model, removed_dict = load_network(
                            self.model, ema_ckpt_dir, self.gpu)

                    if len(removed_dict) > 0:
                        self.print_log(
                            'Remove {} from EMA model.'.format(removed_dict))
                    ema_decay = self.ema.decay
                    del self.ema

                    ema_params = get_param_buffer_for_ema(
                        ema_model, update_buffer=(not cfg.MODEL_FREEZE_BN))
                    self.ema = ExponentialMovingAverage(ema_params,
                                                        decay=ema_decay)
                    self.ema.num_updates = cfg.TRAIN_RESUME_CKPT
                except Exception as inst:
                    self.print_log(inst)
                    self.print_log('Error: EMA model not found!')

            try:
                resume_ckpt = os.path.join(
                    cfg.DIR_CKPT, 'save_step_%s.pth' % (cfg.TRAIN_RESUME_CKPT))
                self.model, self.optimizer, removed_dict = load_network_and_optimizer(
                    self.model,
                    self.optimizer,
                    resume_ckpt,
                    self.gpu,
                    scaler=self.scaler)
            except Exception as inst:
                self.print_log(inst)
                self.print_log('Try to use backup checkpoint.')
                DIR_RESULT = './backup/{}/{}'.format(cfg.EXP_NAME,
                                                     cfg.STAGE_NAME)
                DIR_CKPT = os.path.join(DIR_RESULT, 'ckpt')
                resume_ckpt = os.path.join(
                    DIR_CKPT, 'save_step_%s.pth' % (cfg.TRAIN_RESUME_CKPT))
                self.model, self.optimizer, removed_dict = load_network_and_optimizer(
                    self.model,
                    self.optimizer,
                    resume_ckpt,
                    self.gpu,
                    scaler=self.scaler)

            if len(removed_dict) > 0:
                self.print_log(
                    'Remove {} from checkpoint.'.format(removed_dict))

            self.step = cfg.TRAIN_RESUME_STEP
            if cfg.TRAIN_TOTAL_STEPS <= self.step:
                self.print_log("Your training has finished!")
                exit()
            self.epoch = int(np.ceil(self.step / len(self.train_loader)))

            self.print_log('Resume from step {}'.format(self.step))

        elif cfg.PRETRAIN:
            if cfg.PRETRAIN_FULL:
                try:
                    self.model, removed_dict = load_network(
                        self.model, cfg.PRETRAIN_MODEL, self.gpu)
                except Exception as inst:
                    self.print_log(inst)
                    self.print_log('Try to use backup EMA checkpoint.')
                    DIR_RESULT = './backup/{}/{}'.format(
                        cfg.EXP_NAME, cfg.STAGE_NAME)
                    DIR_EMA_CKPT = os.path.join(DIR_RESULT, 'ema_ckpt')
                    PRETRAIN_MODEL = os.path.join(
                        DIR_EMA_CKPT,
                        cfg.PRETRAIN_MODEL.split('/')[-1])
                    self.model, removed_dict = load_network(
                        self.model, PRETRAIN_MODEL, self.gpu)

                if len(removed_dict) > 0:
                    self.print_log('Remove {} from pretrained model.'.format(
                        removed_dict))
                self.print_log('Load pretrained VOS model from {}.'.format(
                    cfg.PRETRAIN_MODEL))
            else:
                model_encoder, removed_dict = load_network(
                    self.model_encoder, cfg.PRETRAIN_MODEL, self.gpu)
                if len(removed_dict) > 0:
                    self.print_log('Remove {} from pretrained model.'.format(
                        removed_dict))
                self.print_log(
                    'Load pretrained backbone model from {}.'.format(
                        cfg.PRETRAIN_MODEL))

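    # Dataset assembly: each dataset named in cfg.DATASETS is built with the
    # chosen augmentation pipeline and concatenated; static-image pretraining
    # disables the separate previous frame. The per-GPU batch size is
    # TRAIN_BATCH_SIZE / TRAIN_GPUS, and a DistributedSampler handles
    # shuffling when distributed training is enabled.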
    def prepare_dataset(self):
        cfg = self.cfg
        self.enable_prev_frame = cfg.TRAIN_ENABLE_PREV_FRAME

        self.print_log('Process dataset...')
        if cfg.TRAIN_AUG_TYPE == 'v1':
            composed_transforms = transforms.Compose([
                tr.RandomScale(cfg.DATA_MIN_SCALE_FACTOR,
                               cfg.DATA_MAX_SCALE_FACTOR,
                               cfg.DATA_SHORT_EDGE_LEN),
                tr.BalancedRandomCrop(cfg.DATA_RANDOMCROP,
                                      max_obj_num=cfg.MODEL_MAX_OBJ_NUM),
                tr.RandomHorizontalFlip(cfg.DATA_RANDOMFLIP),
                tr.Resize(cfg.DATA_RANDOMCROP, use_padding=True),
                tr.ToTensor()
            ])
        elif cfg.TRAIN_AUG_TYPE == 'v2':
            composed_transforms = transforms.Compose([
                tr.RandomScale(cfg.DATA_MIN_SCALE_FACTOR,
                               cfg.DATA_MAX_SCALE_FACTOR,
                               cfg.DATA_SHORT_EDGE_LEN),
                tr.BalancedRandomCrop(cfg.DATA_RANDOMCROP,
                                      max_obj_num=cfg.MODEL_MAX_OBJ_NUM),
                tr.RandomColorJitter(),
                tr.RandomGrayScale(),
                tr.RandomGaussianBlur(),
                tr.RandomHorizontalFlip(cfg.DATA_RANDOMFLIP),
                tr.Resize(cfg.DATA_RANDOMCROP, use_padding=True),
                tr.ToTensor()
            ])
        else:
            raise NotImplementedError(
                'Unknown TRAIN_AUG_TYPE: {}'.format(cfg.TRAIN_AUG_TYPE))

        train_datasets = []
        if 'static' in cfg.DATASETS:
            pretrain_vos_dataset = StaticTrain(
                cfg.DIR_STATIC,
                cfg.DATA_RANDOMCROP,
                seq_len=cfg.DATA_SEQ_LEN,
                merge_prob=cfg.DATA_DYNAMIC_MERGE_PROB,
                max_obj_n=cfg.MODEL_MAX_OBJ_NUM,
                aug_type=cfg.TRAIN_AUG_TYPE)
            train_datasets.append(pretrain_vos_dataset)
            self.enable_prev_frame = False

        if 'davis2017' in cfg.DATASETS:
            train_davis_dataset = DAVIS2017_Train(
                root=cfg.DIR_DAVIS,
                full_resolution=cfg.TRAIN_DATASET_FULL_RESOLUTION,
                transform=composed_transforms,
                repeat_time=cfg.DATA_DAVIS_REPEAT,
                seq_len=cfg.DATA_SEQ_LEN,
                rand_gap=cfg.DATA_RANDOM_GAP_DAVIS,
                rand_reverse=cfg.DATA_RANDOM_REVERSE_SEQ,
                merge_prob=cfg.DATA_DYNAMIC_MERGE_PROB,
                enable_prev_frame=self.enable_prev_frame,
                max_obj_n=cfg.MODEL_MAX_OBJ_NUM)
            train_datasets.append(train_davis_dataset)

        if 'youtubevos' in cfg.DATASETS:
            train_ytb_dataset = YOUTUBEVOS_Train(
                root=cfg.DIR_YTB,
                transform=composed_transforms,
                seq_len=cfg.DATA_SEQ_LEN,
                rand_gap=cfg.DATA_RANDOM_GAP_YTB,
                rand_reverse=cfg.DATA_RANDOM_REVERSE_SEQ,
                merge_prob=cfg.DATA_DYNAMIC_MERGE_PROB,
                enable_prev_frame=self.enable_prev_frame,
                max_obj_n=cfg.MODEL_MAX_OBJ_NUM)
            train_datasets.append(train_ytb_dataset)

        if 'test' in cfg.DATASETS:
            test_dataset = TEST(transform=composed_transforms,
                                seq_len=cfg.DATA_SEQ_LEN)
            train_datasets.append(test_dataset)

        if len(train_datasets) > 1:
            train_dataset = torch.utils.data.ConcatDataset(train_datasets)
        elif len(train_datasets) == 1:
            train_dataset = train_datasets[0]
        else:
            self.print_log('No dataset!')
            exit(0)

        self.train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset) if self.cfg.DIST_ENABLE else None
        self.train_loader = DataLoader(train_dataset,
                                       batch_size=int(cfg.TRAIN_BATCH_SIZE /
                                                      cfg.TRAIN_GPUS),
                                       shuffle=(self.train_sampler is None),
                                       num_workers=cfg.DATA_WORKERS,
                                       pin_memory=True,
                                       sampler=self.train_sampler,
                                       drop_last=True,
                                       prefetch_factor=4)

        self.print_log('Done!')

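    # Sequential training: every batch is a short clip (reference frame,
    # previous frame, and DATA_SEQ_LEN - 1 current frames). Early in training
    # the engine propagates from ground-truth masks; once step passes
    # TRAIN_SEQ_TRAINING_START_RATIO * TRAIN_TOTAL_STEPS it switches to its
    # own previous predictions (use_prev_pred), optionally freezing the
    # parameters listed in TRAIN_SEQ_TRAINING_FREEZE_PARAMS.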
    def sequential_training(self):

        cfg = self.cfg

        if self.enable_prev_frame:
            frame_names = ['Ref', 'Prev']
        else:
            frame_names = ['Ref(Prev)']

        for i in range(cfg.DATA_SEQ_LEN - 1):
            frame_names.append('Curr{}'.format(i + 1))

        seq_len = len(frame_names)

        running_losses = []
        running_ious = []
        for _ in range(seq_len):
            running_losses.append(AverageMeter())
            running_ious.append(AverageMeter())
        batch_time = AverageMeter()
        avg_obj = AverageMeter()

        optimizer = self.optimizer
        model = self.dist_engine
        train_sampler = self.train_sampler
        train_loader = self.train_loader
        step = self.step
        epoch = self.epoch
        max_itr = cfg.TRAIN_TOTAL_STEPS
        start_seq_training_step = int(cfg.TRAIN_SEQ_TRAINING_START_RATIO *
                                      max_itr)
        use_prev_prob = cfg.MODEL_USE_PREV_PROB

        self.print_log('Start training:')
        model.train()
        while step < cfg.TRAIN_TOTAL_STEPS:
            if self.cfg.DIST_ENABLE:
                train_sampler.set_epoch(epoch)
            epoch += 1
            last_time = time.time()
            for sample in train_loader:
                if step > cfg.TRAIN_TOTAL_STEPS:
                    break

                tf_board = (step % cfg.TRAIN_TBLOG_STEP == 0
                            and self.rank == 0 and cfg.TRAIN_TBLOG)

                if step >= start_seq_training_step:
                    use_prev_pred = True
                    freeze_params = cfg.TRAIN_SEQ_TRAINING_FREEZE_PARAMS
                else:
                    use_prev_pred = False
                    freeze_params = []

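                # LR schedule: warm up for the first TRAIN_LR_WARM_UP_RATIO
                # fraction of steps, then decay (polynomial with power
                # TRAIN_LR_POWER, or cosine) toward TRAIN_LR_MIN; the encoder
                # LR is scaled by TRAIN_LR_ENCODER_RATIO and frozen params
                # are skipped.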
                if step % cfg.TRAIN_LR_UPDATE_STEP == 0:
                    now_lr = adjust_learning_rate(
                        optimizer=optimizer,
                        base_lr=cfg.TRAIN_LR,
                        p=cfg.TRAIN_LR_POWER,
                        itr=step,
                        max_itr=max_itr,
                        restart=cfg.TRAIN_LR_RESTART,
                        warm_up_steps=cfg.TRAIN_LR_WARM_UP_RATIO * max_itr,
                        is_cosine_decay=cfg.TRAIN_LR_COSINE_DECAY,
                        min_lr=cfg.TRAIN_LR_MIN,
                        encoder_lr_ratio=cfg.TRAIN_LR_ENCODER_RATIO,
                        freeze_params=freeze_params)

                ref_imgs = sample['ref_img']
                prev_imgs = sample['prev_img']
                curr_imgs = sample['curr_img']
                ref_labels = sample['ref_label']
                prev_labels = sample['prev_label']
                curr_labels = sample['curr_label']
                obj_nums = sample['meta']['obj_num']
                bs, _, h, w = curr_imgs[0].size()

                ref_imgs = ref_imgs.cuda(self.gpu, non_blocking=True)
                prev_imgs = prev_imgs.cuda(self.gpu, non_blocking=True)
                curr_imgs = [
                    curr_img.cuda(self.gpu, non_blocking=True)
                    for curr_img in curr_imgs
                ]
                ref_labels = ref_labels.cuda(self.gpu, non_blocking=True)
                prev_labels = prev_labels.cuda(self.gpu, non_blocking=True)
                curr_labels = [
                    curr_label.cuda(self.gpu, non_blocking=True)
                    for curr_label in curr_labels
                ]
                obj_nums = [int(obj_num) for obj_num in obj_nums]

                batch_size = ref_imgs.size(0)

                all_frames = torch.cat([ref_imgs, prev_imgs] + curr_imgs,
                                       dim=0)
                all_labels = torch.cat([ref_labels, prev_labels] + curr_labels,
                                       dim=0)

                self.engine.restart_engine(batch_size, True)
                optimizer.zero_grad(set_to_none=True)

                if self.enable_amp:
                    with torch.cuda.amp.autocast(enabled=True):
                        loss, all_pred, all_loss, boards = model(
                            all_frames,
                            all_labels,
                            batch_size,
                            use_prev_pred=use_prev_pred,
                            obj_nums=obj_nums,
                            step=step,
                            tf_board=tf_board,
                            enable_prev_frame=self.enable_prev_frame,
                            use_prev_prob=use_prev_prob)
                        loss = torch.mean(loss)

                    self.scaler.scale(loss).backward()
                    # Unscale before clipping so the norm is computed on the
                    # true gradient values.
                    self.scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   cfg.TRAIN_CLIP_GRAD_NORM)
                    self.scaler.step(optimizer)
                    self.scaler.update()

                else:
                    loss, all_pred, all_loss, boards = model(
                        all_frames,
                        all_labels,
                        batch_size,
                        use_prev_pred=use_prev_pred,
                        obj_nums=obj_nums,
                        step=step,
                        tf_board=tf_board,
                        enable_prev_frame=self.enable_prev_frame,
                        use_prev_prob=use_prev_prob)
                    loss = torch.mean(loss)

                    # Clip only after backward() has populated the gradients.
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   cfg.TRAIN_CLIP_GRAD_NORM)
                    optimizer.step()

                for idx in range(seq_len):
                    now_pred = all_pred[idx].detach()
                    now_label = all_labels[idx * bs:(idx + 1) * bs].detach()
                    now_loss = torch.mean(all_loss[idx].detach())
                    now_iou = pytorch_iou(now_pred.unsqueeze(1), now_label,
                                          obj_nums) * 100
                    if self.cfg.DIST_ENABLE:
                        # Sum across GPUs, then average per process.
                        dist.all_reduce(now_loss)
                        dist.all_reduce(now_iou)
                        now_loss /= self.gpu_num
                        now_iou /= self.gpu_num
                    if self.rank == 0:
                        running_losses[idx].update(now_loss.item())
                        running_ious[idx].update(now_iou.item())

                if self.rank == 0:
                    self.ema.update(self.ema_params)

                    avg_obj.update(sum(obj_nums) / float(len(obj_nums)))
                    curr_time = time.time()
                    batch_time.update(curr_time - last_time)
                    last_time = curr_time

                    if step % cfg.TRAIN_TBLOG_STEP == 0:
                        all_f = [ref_imgs, prev_imgs] + curr_imgs
                        self.process_log(ref_imgs, all_f[-2], all_f[-1],
                                         ref_labels, all_pred[-2], now_label,
                                         now_pred, boards, running_losses,
                                         running_ious, now_lr, step)

                    if step % cfg.TRAIN_LOG_STEP == 0:
                        strs = 'I:{}, LR:{:.5f}, T:{:.1f}({:.1f})s, Obj:{:.1f}({:.1f})'.format(
                            step, now_lr, batch_time.val,
                            batch_time.moving_avg, avg_obj.val,
                            avg_obj.moving_avg)
                        batch_time.reset()
                        avg_obj.reset()
                        for idx in range(seq_len):
                            strs += ', {}: L {:.3f}({:.3f}) IoU {:.1f}({:.1f})%'.format(
                                frame_names[idx], running_losses[idx].val,
                                running_losses[idx].moving_avg,
                                running_ious[idx].val,
                                running_ious[idx].moving_avg)
                            running_losses[idx].reset()
                            running_ious[idx].reset()

                        self.print_log(strs)

                step += 1

                if step % cfg.TRAIN_SAVE_STEP == 0 and self.rank == 0:
                    max_mem = torch.cuda.max_memory_allocated(
                        device=self.gpu) / (1024.**3)
                    ETA = str(
                        datetime.timedelta(
                            seconds=int(batch_time.moving_avg *
                                        (cfg.TRAIN_TOTAL_STEPS - step))))
                    self.print_log('ETA: {}, Max Mem: {:.2f}G.'.format(
                        ETA, max_mem))
                    self.print_log('Save CKPT (Step {}).'.format(step))
                    save_network(self.model,
                                 optimizer,
                                 step,
                                 cfg.DIR_CKPT,
                                 cfg.TRAIN_MAX_KEEP_CKPT,
                                 backup_dir='./backup/{}/{}/ckpt'.format(
                                     cfg.EXP_NAME, cfg.STAGE_NAME),
                                 scaler=self.scaler)
                    try:
                        torch.cuda.empty_cache()

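                        # Checkpoint the EMA weights: store() stashes the
                        # live parameters, copy_to() swaps in the averaged
                        # weights for saving, and restore() swaps the live
                        # parameters back so training continues unaffected.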
                        self.ema.store(self.ema_params)

                        self.ema.copy_to(self.ema_params)

                        save_network(
                            self.model,
                            optimizer,
                            step,
                            self.ema_dir,
                            cfg.TRAIN_MAX_KEEP_CKPT,
                            backup_dir='./backup/{}/{}/ema_ckpt'.format(
                                cfg.EXP_NAME, cfg.STAGE_NAME),
                            scaler=self.scaler)

                        self.ema.restore(self.ema_params)
                    except Exception as inst:
                        self.print_log(inst)
                        self.print_log('Error: failed to save EMA model!')

        self.print_log('Stop training!')

    def print_log(self, string):
        if self.rank == 0:
            print(string)

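    # Visualization/logging: denormalizes the displayed frames with the
    # ImageNet mean/std, overlays ground-truth and predicted masks, and
    # writes the results to DIR_IMG_LOG and/or TensorBoard.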
    def process_log(self, ref_imgs, prev_imgs, curr_imgs, ref_labels,
                    prev_labels, curr_labels, curr_pred, boards,
                    running_losses, running_ious, now_lr, step):
        cfg = self.cfg

        mean = np.array([[[0.485]], [[0.456]], [[0.406]]])
        sigma = np.array([[[0.229]], [[0.224]], [[0.225]]])

        show_ref_img, show_prev_img, show_curr_img = [
            img.cpu().numpy()[0] * sigma + mean
            for img in [ref_imgs, prev_imgs, curr_imgs]
        ]

        show_gt, show_prev_gt, show_ref_gt, show_preds_s = [
            label.cpu()[0].squeeze(0).numpy()
            for label in [curr_labels, prev_labels, ref_labels, curr_pred]
        ]

        show_gtf, show_prev_gtf, show_ref_gtf, show_preds_sf = [
            label2colormap(label).transpose((2, 0, 1))
            for label in [show_gt, show_prev_gt, show_ref_gt, show_preds_s]
        ]

        if cfg.TRAIN_IMG_LOG or cfg.TRAIN_TBLOG:
            show_ref_img = masked_image(show_ref_img, show_ref_gtf,
                                        show_ref_gt)
            if cfg.TRAIN_IMG_LOG:
                save_image(
                    show_ref_img,
                    os.path.join(cfg.DIR_IMG_LOG,
                                 '%06d_ref_img.jpeg' % (step)))

            show_prev_img = masked_image(show_prev_img, show_prev_gtf,
                                         show_prev_gt)
            if cfg.TRAIN_IMG_LOG:
                save_image(
                    show_prev_img,
                    os.path.join(cfg.DIR_IMG_LOG,
                                 '%06d_prev_img.jpeg' % (step)))

            show_img_pred = masked_image(show_curr_img, show_preds_sf,
                                         show_preds_s)
            if cfg.TRAIN_IMG_LOG:
                save_image(
                    show_img_pred,
                    os.path.join(cfg.DIR_IMG_LOG,
                                 '%06d_prediction.jpeg' % (step)))

            show_curr_img = masked_image(show_curr_img, show_gtf, show_gt)
            if cfg.TRAIN_IMG_LOG:
                save_image(
                    show_curr_img,
                    os.path.join(cfg.DIR_IMG_LOG,
                                 '%06d_groundtruth.jpeg' % (step)))

            if cfg.TRAIN_TBLOG:
                for seq_step, running_loss, running_iou in zip(
                        range(len(running_losses)), running_losses,
                        running_ious):
                    self.tblogger.add_scalar('S{}/Loss'.format(seq_step),
                                             running_loss.avg, step)
                    self.tblogger.add_scalar('S{}/IoU'.format(seq_step),
                                             running_iou.avg, step)

                self.tblogger.add_scalar('LR', now_lr, step)
                self.tblogger.add_image('Ref/Image', show_ref_img, step)
                self.tblogger.add_image('Ref/GT', show_ref_gtf, step)

                self.tblogger.add_image('Prev/Image', show_prev_img, step)
                self.tblogger.add_image('Prev/GT', show_prev_gtf, step)

                self.tblogger.add_image('Curr/Image_GT', show_curr_img, step)
                self.tblogger.add_image('Curr/Image_Pred', show_img_pred,
                                        step)

                self.tblogger.add_image('Curr/Mask_GT', show_gtf, step)
                self.tblogger.add_image('Curr/Mask_Pred', show_preds_sf, step)

                for key in boards['image'].keys():
                    tmp = boards['image'][key].cpu().numpy()
                    self.tblogger.add_image(key, tmp, step)
                for key in boards['scalar'].keys():
                    tmp = boards['scalar'][key].cpu().numpy()
                    self.tblogger.add_scalar(key, tmp, step)

                self.tblogger.flush()

        del boards
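
# Usage sketch (hypothetical entry point; in practice the Trainer is
# constructed by a launcher script that spawns one process per GPU):
#
#   import torch.multiprocessing as mp
#
#   def main_worker(rank, cfg):
#       trainer = Trainer(rank=rank, cfg=cfg, enable_amp=True)
#       trainer.sequential_training()
#
#   mp.spawn(main_worker, nprocs=cfg.TRAIN_GPUS, args=(cfg,))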