import math
import os

import torch
import torch.distributed as dist

# apex is optional; `amp` stays None when it is not installed, and the AMP
# branches below are only reached when config.AMP_OPT_LEVEL != "O0".
try:
    from apex import amp
except ImportError:
    amp = None
|
|
|
class AverageMeter(object): |
|
"""Computes and stores the average and current value""" |
|
def __init__(self, name, fmt=':f'): |
|
self.name = name |
|
self.fmt = fmt |
|
self.reset() |
|
|
|
def reset(self): |
|
self.val = 0 |
|
self.avg = 0 |
|
self.sum = 0 |
|
self.count = 0 |
|
|
|
def update(self, val, n=1): |
|
self.val = val |
|
self.sum += val * n |
|
self.count += n |
|
self.avg = self.sum / self.count |
|
|
|
def __str__(self): |
|
fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' |
|
return fmtstr.format(**self.__dict__) |
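
# A minimal usage sketch for AverageMeter (values are illustrative):
#
#   losses = AverageMeter('Loss', ':.4e')
#   losses.update(0.731, n=32)   # latest batch loss, batch size 32
#   losses.update(0.702, n=32)
#   print(losses)                # -> "Loss 7.0200e-01 (7.1650e-01)"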
|
|
|
|
|
class ProgressMeter(object):
    """Formats and prints a progress line: a batch counter plus all meters."""
|
def __init__(self, num_batches, meters, prefix=""): |
|
self.batch_fmtstr = self._get_batch_fmtstr(num_batches) |
|
self.meters = meters |
|
self.prefix = prefix |
|
|
|
def display(self, batch): |
|
entries = [self.prefix + self.batch_fmtstr.format(batch)] |
|
entries += [str(meter) for meter in self.meters] |
|
print('\t'.join(entries)) |
|
|
|
def _get_batch_fmtstr(self, num_batches): |
|
        num_digits = len(str(num_batches))
|
fmt = '{:' + str(num_digits) + 'd}' |
|
return '[' + fmt + '/' + fmt.format(num_batches) + ']' |
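
# Sketch of how the meters combine in a training loop (assumes an iterable
# `loader` and a batch index `i`; illustrative only):
#
#   batch_time = AverageMeter('Time', ':6.3f')
#   losses = AverageMeter('Loss', ':.4e')
#   progress = ProgressMeter(len(loader), [batch_time, losses],
#                            prefix="Epoch: [0]")
#   ...
#   progress.display(i)   # e.g. "Epoch: [0][ 10/391]\tTime ...\tLoss ..."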
|
|
|
|
|
def accuracy(output, target, topk=(1,)): |
|
"""Computes the accuracy over the k top predictions for the specified values of k""" |
|
with torch.no_grad(): |
|
maxk = max(topk) |
|
batch_size = target.size(0) |
|
|
|
_, pred = output.topk(maxk, 1, True, True) |
|
pred = pred.t() |
|
correct = pred.eq(target.view(1, -1).expand_as(pred)) |
|
|
|
res = [] |
|
for k in topk: |
|
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) |
|
res.append(correct_k.mul_(100.0 / batch_size)) |
|
return res |
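
# Example call: top-1 / top-5 accuracy for one batch of logits (shapes are
# illustrative):
#
#   output = model(images)                      # (N, num_classes) logits
#   acc1, acc5 = accuracy(output, target, topk=(1, 5))
#   top1.update(acc1.item(), images.size(0))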
|
|
|
def load_model_state(model, ckpt_path):
    """Load weights from `ckpt_path` into `model`, unwrapping the common
    'model'/'state_dict' checkpoint layouts and stripping any 'module.'
    prefix left by (Distributed)DataParallel.

    Distinct from the config-driven `load_checkpoint` resume helper below.
    """
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    if 'model' in checkpoint:
        checkpoint = checkpoint['model']
    if 'state_dict' in checkpoint:
        checkpoint = checkpoint['state_dict']
    ckpt = {}
    for k, v in checkpoint.items():
        if k.startswith('module.'):
            ckpt[k[7:]] = v  # drop the 'module.' prefix
        else:
            ckpt[k] = v
    model.load_state_dict(ckpt)
|
|
|
|
|
class WarmupCosineAnnealingLR(torch.optim.lr_scheduler._LRScheduler):
    """Cosine annealing with a linear warmup phase.

    The learning rate ramps linearly from 0 to the base LR over `warmup`
    epochs, then decays along a cosine curve to `eta_min` at `T_cosine_max`.
    """
|
|
|
def __init__(self, optimizer, T_cosine_max, eta_min=0, last_epoch=-1, warmup=0): |
|
self.eta_min = eta_min |
|
self.T_cosine_max = T_cosine_max |
|
self.warmup = warmup |
|
super(WarmupCosineAnnealingLR, self).__init__(optimizer, last_epoch) |
|
|
|
def get_lr(self): |
|
if self.last_epoch < self.warmup: |
|
return [self.last_epoch / self.warmup * base_lr for base_lr in self.base_lrs] |
|
else: |
|
return [self.eta_min + (base_lr - self.eta_min) * |
|
(1 + math.cos(math.pi * (self.last_epoch - self.warmup) / (self.T_cosine_max - self.warmup))) / 2 |
|
for base_lr in self.base_lrs] |
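
# Construction sketch, assuming the scheduler is stepped once per epoch
# (hyperparameters are illustrative):
#
#   optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
#   scheduler = WarmupCosineAnnealingLR(optimizer, T_cosine_max=100,
#                                       eta_min=1e-5, warmup=5)
#   for epoch in range(100):
#       train_one_epoch(...)   # hypothetical training step
#       scheduler.step()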
|
|
|
|
|
def log_msg(message, log_file): |
|
print(message) |
|
with open(log_file, 'a') as f: |
|
print(message, file=f) |
|
|
def unwrap_model(model): |
|
"""Remove the DistributedDataParallel wrapper if present.""" |
|
wrapped = isinstance(model, torch.nn.parallel.distributed.DistributedDataParallel) |
|
return model.module if wrapped else model |
|
|
|
|
|
def load_checkpoint(config, model, optimizer, lr_scheduler, logger, model_ema=None): |
|
logger.info(f"==============> Resuming form {config.MODEL.RESUME}....................") |
|
if config.MODEL.RESUME.startswith('https'): |
|
checkpoint = torch.hub.load_state_dict_from_url( |
|
config.MODEL.RESUME, map_location='cpu', check_hash=True) |
|
else: |
|
checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu') |
|
msg = model.load_state_dict(checkpoint['model'], strict=False) |
|
logger.info(msg) |
|
max_accuracy = 0.0 |
|
if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: |
|
optimizer.load_state_dict(checkpoint['optimizer']) |
|
lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) |
|
config.defrost() |
|
config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1 |
|
config.freeze() |
|
        if amp is not None and 'amp' in checkpoint and config.AMP_OPT_LEVEL != "O0" and checkpoint['config'].AMP_OPT_LEVEL != "O0":
|
amp.load_state_dict(checkpoint['amp']) |
|
logger.info(f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})") |
|
if 'max_accuracy' in checkpoint: |
|
max_accuracy = checkpoint['max_accuracy'] |
|
    if model_ema is not None and 'ema' in checkpoint:
        unwrap_model(model_ema).load_state_dict(checkpoint['ema'])
        logger.info("=> EMA weights loaded")
|
|
|
del checkpoint |
|
torch.cuda.empty_cache() |
|
return max_accuracy |
|
|
|
|
|
def load_weights(model, path): |
|
checkpoint = torch.load(path, map_location='cpu') |
|
if 'model' in checkpoint: |
|
checkpoint = checkpoint['model'] |
|
if 'state_dict' in checkpoint: |
|
checkpoint = checkpoint['state_dict'] |
|
unwrap_model(model).load_state_dict(checkpoint, strict=False) |
|
print('=================== loaded from', path) |
|
|
|
def save_latest(config, epoch, model, max_accuracy, optimizer, lr_scheduler, logger, model_ema=None): |
|
save_state = {'model': model.state_dict(), |
|
'optimizer': optimizer.state_dict(), |
|
'lr_scheduler': lr_scheduler.state_dict(), |
|
'max_accuracy': max_accuracy, |
|
'epoch': epoch, |
|
'config': config} |
|
if config.AMP_OPT_LEVEL != "O0": |
|
save_state['amp'] = amp.state_dict() |
|
if model_ema is not None: |
|
save_state['ema'] = unwrap_model(model_ema).state_dict() |
|
|
|
save_path = os.path.join(config.OUTPUT, 'latest.pth') |
|
logger.info(f"{save_path} saving......") |
|
torch.save(save_state, save_path) |
|
logger.info(f"{save_path} saved !!!") |
|
|
|
def save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, logger, is_best=False, model_ema=None): |
|
save_state = {'model': model.state_dict(), |
|
'optimizer': optimizer.state_dict(), |
|
'lr_scheduler': lr_scheduler.state_dict(), |
|
'max_accuracy': max_accuracy, |
|
'epoch': epoch, |
|
'config': config} |
|
if config.AMP_OPT_LEVEL != "O0": |
|
save_state['amp'] = amp.state_dict() |
|
if model_ema is not None: |
|
save_state['ema'] = unwrap_model(model_ema).state_dict() |
|
|
|
if is_best: |
|
best_path = os.path.join(config.OUTPUT, 'best_ckpt.pth') |
|
torch.save(save_state, best_path) |
|
|
|
save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth') |
|
logger.info(f"{save_path} saving......") |
|
torch.save(save_state, save_path) |
|
logger.info(f"{save_path} saved !!!") |
|
|
|
|
|
def get_grad_norm(parameters, norm_type=2):
    """Compute the total `norm_type`-norm over the gradients of `parameters`."""
|
if isinstance(parameters, torch.Tensor): |
|
parameters = [parameters] |
|
parameters = list(filter(lambda p: p.grad is not None, parameters)) |
|
norm_type = float(norm_type) |
|
total_norm = 0 |
|
for p in parameters: |
|
param_norm = p.grad.data.norm(norm_type) |
|
total_norm += param_norm.item() ** norm_type |
|
total_norm = total_norm ** (1. / norm_type) |
|
return total_norm |
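
# Typical call site, just after the backward pass (sketch only):
#
#   loss.backward()
#   grad_norm = get_grad_norm(model.parameters())
#   optimizer.step()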
|
|
|
|
|
|
|
|
def auto_resume_helper(output_dir):
    if not os.path.isdir(output_dir):
        return None
    checkpoints = os.listdir(output_dir)
    checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth') and 'ema' not in ckpt]
    print(f"All checkpoints found in {output_dir}: {checkpoints}")
|
if len(checkpoints) > 0: |
|
latest_checkpoint = max([os.path.join(output_dir, d) for d in checkpoints], key=os.path.getmtime) |
|
print(f"The latest checkpoint founded: {latest_checkpoint}") |
|
resume_file = latest_checkpoint |
|
else: |
|
resume_file = None |
|
return resume_file |
|
|
|
|
|
def reduce_tensor(tensor):
    """Average `tensor` across all distributed processes."""
|
rt = tensor.clone() |
|
dist.all_reduce(rt, op=dist.ReduceOp.SUM) |
|
rt /= dist.get_world_size() |
|
return rt |
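
# Sketch: averaging per-batch metrics across DDP ranks so that every process
# logs the same value (assumes torch.distributed is initialized):
#
#   acc1 = reduce_tensor(acc1)
#   loss_val = reduce_tensor(loss.detach())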
|
|
|
def update_model_ema(cfg, num_gpus, model, model_ema, cur_epoch, cur_iter): |
|
"""Update exponential moving average (ema) of model weights.""" |
|
update_period = cfg.TRAIN.EMA_UPDATE_PERIOD |
|
if update_period is None or update_period == 0 or cur_iter % update_period != 0: |
|
return |
|
|
|
    # Scale the base EMA decay so the effective averaging horizon stays
    # consistent across different total batch sizes and update periods.
    total_batch_size = num_gpus * cfg.DATA.BATCH_SIZE
    adjust = total_batch_size / cfg.TRAIN.EPOCHS * update_period
    alpha = min(1.0, cfg.TRAIN.EMA_ALPHA * adjust)

    # During warmup the EMA model tracks the raw weights exactly.
    alpha = 1.0 if cur_epoch < cfg.TRAIN.WARMUP_EPOCHS else alpha
|
|
|
params = unwrap_model(model).state_dict() |
|
for name, param in unwrap_model(model_ema).state_dict().items(): |
|
param.copy_(param * (1.0 - alpha) + params[name] * alpha) |
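
# Call sketch, once per training iteration (field names follow the yacs
# config keys referenced above):
#
#   update_model_ema(cfg, num_gpus, model, model_ema,
#                    cur_epoch=epoch, cur_iter=it)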