import torch |
import math |
import os |
class AverageMeter(object):
    """Track the most recent value and the running weighted average of a metric.

    Attributes:
        name: Label printed in front of the numbers.
        fmt: Format spec (including the leading ``:``) used for both numbers.
        val: Value passed to the latest ``update`` call.
        sum: Weighted sum of all values seen since the last ``reset``.
        count: Total weight (sum of ``n``) seen since the last ``reset``.
        avg: ``sum / count``, or 0 while ``count`` is still 0.
    """

    def __init__(self, name, fmt=':f'):
        """Create an empty meter.

        Args:
            name: Label used when the meter is rendered as a string.
            fmt: Format spec appended to the ``val`` and ``avg`` fields.
        """
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Discard all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average.

        Args:
            val: Newly observed value.
            n: Weight of the observation (e.g. number of samples it covers).
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (for instance a first update with n=0) used to
        # raise ZeroDivisionError; report an average of 0 instead.
        self.avg = self.sum / self.count if self.count else 0

    def __str__(self):
        """Render as ``"<name> <val> (<avg>)"`` using ``self.fmt`` for both numbers."""
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**vars(self))
class ProgressMeter(object):
    """Print a ``[batch/total]`` counter followed by a list of meters on one line."""

    def __init__(self, num_batches, meters, prefix=""):
        """Store the meters and pre-build the counter template.

        Args:
            num_batches: Total number of batches; fixes the counter width.
            meters: Iterable of objects rendered with ``str()`` on display.
            prefix: Text placed in front of the counter.
        """
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.prefix = prefix
        self.meters = meters

    def display(self, batch):
        """Print the prefix, the counter for ``batch`` and every meter, tab-separated."""
        columns = [self.prefix + self.batch_fmtstr.format(batch)]
        columns.extend(str(meter) for meter in self.meters)
        print('\t'.join(columns))

    def _get_batch_fmtstr(self, num_batches):
        """Return a template such as ``'[{:3d}/100]'`` sized to ``num_batches``."""
        width = len(str(num_batches // 1))
        counter = f'{{:{width}d}}'
        return f'[{counter}/{counter.format(num_batches)}]'
def accuracy(output, target, topk=(1,)):
    """Return the top-k accuracy (in percent) for every k in ``topk``.

    Args:
        output: Score tensor; the k best entries are taken along dim 1.
        target: Tensor of ground-truth indices; its first dimension is used
            as the batch size.
        topk: Iterable of k values to evaluate.

    Returns:
        A list with one single-element tensor per requested k, in the same
        order as ``topk``.
    """
    with torch.no_grad():
        largest_k = max(topk)
        num_samples = target.size(0)

        # Indices of the best `largest_k` scores, transposed so that row i
        # holds the (i+1)-th ranked prediction of every sample.
        ranked = output.topk(largest_k, 1, True, True)[1].t()
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))

        percent = 100.0 / num_samples
        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(percent)
            for k in topk
        ]
def load_checkpoint(model, ckpt_path):
    """Load weights stored at ``ckpt_path`` into ``model``.

    The file may hold the state dict directly or nest it under a ``'model'``
    and/or ``'state_dict'`` key. A leading ``'module.'`` on parameter names
    is stripped before loading.

    NOTE(review): a later ``def load_checkpoint(config, model, ...)`` in this
    module rebinds the same name, so this version is not reachable by name
    once the module has been fully imported. Consider renaming one of them.

    Args:
        model: Module whose ``load_state_dict`` receives the cleaned weights.
        ckpt_path: Path passed to ``torch.load``.
    """
    # Deserialize onto the CPU so checkpoints written from a GPU can be read
    # on hosts without one; load_state_dict copies into the model's tensors.
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    if 'model' in checkpoint:
        checkpoint = checkpoint['model']
    if 'state_dict' in checkpoint:
        checkpoint = checkpoint['state_dict']

    prefix = 'module.'
    state = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint.items()
    }
    model.load_state_dict(state)
class WarmupCosineAnnealingLR(torch.optim.lr_scheduler._LRScheduler):
    """Linear warmup followed by cosine annealing of the learning rate.

    While ``last_epoch < warmup`` each rate is ``last_epoch / warmup`` times
    its base value (so it starts at 0). Afterwards the rate follows a cosine
    curve from the base value towards ``eta_min``, reaching ``eta_min`` when
    ``last_epoch == T_cosine_max``.
    """

    def __init__(self, optimizer, T_cosine_max, eta_min=0, last_epoch=-1, warmup=0):
        """Configure the schedule.

        Args:
            optimizer: Wrapped optimizer.
            T_cosine_max: Epoch index at which the cosine reaches ``eta_min``.
            eta_min: Lower bound of the cosine phase.
            last_epoch: Forwarded to the base scheduler.
            warmup: Number of warmup epochs.
        """
        # Assigned before delegating to the base constructor, since it may
        # already invoke get_lr().
        self.warmup = warmup
        self.T_cosine_max = T_cosine_max
        self.eta_min = eta_min
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate of every parameter group for ``last_epoch``."""
        epoch = self.last_epoch
        if epoch < self.warmup:
            ramp = epoch / self.warmup
            return [ramp * base_lr for base_lr in self.base_lrs]

        angle = math.pi * (epoch - self.warmup) / (self.T_cosine_max - self.warmup)
        cosine = 1 + math.cos(angle)
        return [
            self.eta_min + (base_lr - self.eta_min) * cosine / 2
            for base_lr in self.base_lrs
        ]
def log_msg(message, log_file):
    """Print ``message`` to stdout and append it as one line to ``log_file``.

    Args:
        message: Object to log; it is rendered with ``str()``.
        log_file: Path of the file opened in append mode.
    """
    print(message)
    with open(log_file, 'a') as handle:
        handle.write(str(message) + '\n')
try: |
from apex import amp |
except ImportError: |
amp = None |
def unwrap_model(model):
    """Return the module inside a DistributedDataParallel wrapper.

    Any other object is returned unchanged.
    """
    if isinstance(model, torch.nn.parallel.distributed.DistributedDataParallel):
        return model.module
    return model
def load_checkpoint(config, model, optimizer, lr_scheduler, logger, model_ema=None):
    """Restore model (and optionally training) state from ``config.MODEL.RESUME``.

    The checkpoint is fetched with ``torch.hub`` when the location starts with
    ``'https'`` and read with ``torch.load`` otherwise, always onto the CPU.
    Model weights are loaded non-strictly. Unless ``config.EVAL_MODE`` is set,
    and only if the checkpoint carries ``'optimizer'``, ``'lr_scheduler'`` and
    ``'epoch'``, the optimizer, scheduler, start epoch, apex amp state and best
    accuracy are restored as well.

    Args:
        config: Configuration object; ``TRAIN.START_EPOCH`` is updated between
            ``defrost()`` and ``freeze()`` when training state is restored.
        model: Module receiving ``checkpoint['model']``.
        optimizer: Optimizer receiving ``checkpoint['optimizer']``.
        lr_scheduler: Scheduler receiving ``checkpoint['lr_scheduler']``.
        logger: Object with an ``info`` method.
        model_ema: Optional EMA model receiving ``checkpoint['ema']``.

    Returns:
        ``checkpoint['max_accuracy']`` when training state was restored and
        the key exists, otherwise ``0.0``.
    """
    logger.info(f"==============> Resuming form {config.MODEL.RESUME}....................")
    is_remote = config.MODEL.RESUME.startswith('https')
    if is_remote:
        checkpoint = torch.hub.load_state_dict_from_url(
            config.MODEL.RESUME, map_location='cpu', check_hash=True)
    else:
        checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')

    load_report = model.load_state_dict(checkpoint['model'], strict=False)
    logger.info(load_report)

    best_accuracy = 0.0
    training_keys = ('optimizer', 'lr_scheduler', 'epoch')
    if not config.EVAL_MODE and all(key in checkpoint for key in training_keys):
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])

        # The config is frozen by default; open it just long enough to move
        # the start epoch past the one stored in the checkpoint.
        config.defrost()
        config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
        config.freeze()

        # Restore amp only if both this run and the saved run have it enabled.
        restore_amp = (
            'amp' in checkpoint
            and config.AMP_OPT_LEVEL != "O0"
            and checkpoint['config'].AMP_OPT_LEVEL != "O0"
        )
        if restore_amp:
            amp.load_state_dict(checkpoint['amp'])

        logger.info(f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})")
        if 'max_accuracy' in checkpoint:
            best_accuracy = checkpoint['max_accuracy']

    # NOTE(review): the EMA restore is assumed to run independently of the
    # resume branch above (also in eval mode) - confirm against the caller.
    if model_ema is not None:
        unwrap_model(model_ema).load_state_dict(checkpoint['ema'])
        print('=================================================== EMAloaded')

    del checkpoint
    torch.cuda.empty_cache()
    return best_accuracy
def load_weights(model, path):
    """Load weights from ``path`` into ``model`` without strict key matching.

    The file is read onto the CPU. If the loaded object contains a ``'model'``
    key and/or (afterwards) a ``'state_dict'`` key, the nested value is used.
    The weights go into the module returned by ``unwrap_model(model)``.

    Args:
        model: Target module, possibly wrapped.
        path: Checkpoint location passed to ``torch.load``.
    """
    state = torch.load(path, map_location='cpu')
    # Peel the optional containers in the same order they are nested.
    for container_key in ('model', 'state_dict'):
        if container_key in state:
            state = state[container_key]
    target = unwrap_model(model)
    target.load_state_dict(state, strict=False)
    print('=================== loaded from', path)
def _build_save_state(config, epoch, model, max_accuracy, optimizer, lr_scheduler, model_ema=None):
    """Assemble the checkpoint dict shared by ``save_latest`` and ``save_checkpoint``."""
    save_state = {'model': model.state_dict(),
                  'optimizer': optimizer.state_dict(),
                  'lr_scheduler': lr_scheduler.state_dict(),
                  'max_accuracy': max_accuracy,
                  'epoch': epoch,
                  'config': config}
    # amp state is only stored when mixed precision is enabled for this run.
    if config.AMP_OPT_LEVEL != "O0":
        save_state['amp'] = amp.state_dict()
    if model_ema is not None:
        save_state['ema'] = unwrap_model(model_ema).state_dict()
    return save_state


def save_latest(config, epoch, model, max_accuracy, optimizer, lr_scheduler, logger, model_ema=None):
    """Write the current training state to ``<config.OUTPUT>/latest.pth``.

    Args:
        config: Configuration object; stored in the checkpoint and used for
            ``OUTPUT`` and ``AMP_OPT_LEVEL``.
        epoch: Epoch number stored under ``'epoch'``.
        model: Module whose ``state_dict()`` is stored under ``'model'``.
        max_accuracy: Value stored under ``'max_accuracy'``.
        optimizer: Optimizer whose state is stored under ``'optimizer'``.
        lr_scheduler: Scheduler whose state is stored under ``'lr_scheduler'``.
        logger: Object with an ``info`` method.
        model_ema: Optional EMA model stored under ``'ema'``.
    """
    save_state = _build_save_state(
        config, epoch, model, max_accuracy, optimizer, lr_scheduler, model_ema)
    save_path = os.path.join(config.OUTPUT, 'latest.pth')
    logger.info(f"{save_path} saving......")
    torch.save(save_state, save_path)
    logger.info(f"{save_path} saved !!!")


def save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, logger, is_best=False, model_ema=None):
    """Write the training state to ``<config.OUTPUT>/ckpt_epoch_<epoch>.pth``.

    When ``is_best`` is true the same state is additionally written to
    ``<config.OUTPUT>/best_ckpt.pth`` first.

    Args:
        config: Configuration object; stored in the checkpoint and used for
            ``OUTPUT`` and ``AMP_OPT_LEVEL``.
        epoch: Epoch number stored under ``'epoch'`` and used in the file name.
        model: Module whose ``state_dict()`` is stored under ``'model'``.
        max_accuracy: Value stored under ``'max_accuracy'``.
        optimizer: Optimizer whose state is stored under ``'optimizer'``.
        lr_scheduler: Scheduler whose state is stored under ``'lr_scheduler'``.
        logger: Object with an ``info`` method.
        is_best: Whether to also refresh ``best_ckpt.pth``.
        model_ema: Optional EMA model stored under ``'ema'``.
    """
    save_state = _build_save_state(
        config, epoch, model, max_accuracy, optimizer, lr_scheduler, model_ema)
    if is_best:
        best_path = os.path.join(config.OUTPUT, 'best_ckpt.pth')
        torch.save(save_state, best_path)
    save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth')
    logger.info(f"{save_path} saving......")
    torch.save(save_state, save_path)
    logger.info(f"{save_path} saved !!!")
def get_grad_norm(parameters, norm_type=2):
    """Return the norm of all gradients, viewed as a single flattened vector.

    Args:
        parameters: A tensor or an iterable of tensors; entries whose ``grad``
            is ``None`` are ignored.
        norm_type: Order of the p-norm. ``float('inf')`` selects the largest
            absolute gradient entry.

    Returns:
        The norm as a Python number; 0 when no parameter has a gradient.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    norm_type = float(norm_type)

    if norm_type == float('inf'):
        # The generic formula below degenerates for p = inf (x ** inf followed
        # by total ** 0), so the infinity norm is taken as the max directly.
        return max((g.norm(norm_type).item() for g in grads), default=0.0)

    total_norm = sum(g.norm(norm_type).item() ** norm_type for g in grads)
    return total_norm ** (1. / norm_type)
import torch.distributed as dist |
def auto_resume_helper(output_dir):
    """Return the most recently modified checkpoint file in ``output_dir``.

    Only entries whose name ends with ``'pth'`` and does not contain ``'ema'``
    are considered.

    Args:
        output_dir: Directory to scan (non-recursively).

    Returns:
        The joined path of the newest matching file, or ``None`` if there is
        no matching file.
    """
    checkpoints = [
        name for name in os.listdir(output_dir)
        if name.endswith('pth') and 'ema' not in name
    ]
    print(f"All checkpoints founded in {output_dir}: {checkpoints}")
    if not checkpoints:
        return None

    # Pick by modification time rather than by the epoch number in the name.
    latest_checkpoint = max(
        (os.path.join(output_dir, name) for name in checkpoints),
        key=os.path.getmtime,
    )
    print(f"The latest checkpoint founded: {latest_checkpoint}")
    return latest_checkpoint
def reduce_tensor(tensor):
    """Return the average of ``tensor`` over all processes in the default group.

    The input is left untouched; a clone is summed with ``all_reduce`` and
    then divided by the world size.
    """
    averaged = tensor.clone()
    dist.all_reduce(averaged, op=dist.ReduceOp.SUM)
    averaged.div_(dist.get_world_size())
    return averaged
def update_model_ema(cfg, num_gpus, model, model_ema, cur_epoch, cur_iter):
    """Blend the current model weights into the EMA model in place.

    Nothing happens when ``cfg.TRAIN.EMA_UPDATE_PERIOD`` is ``None`` or 0, or
    when ``cur_iter`` is not a multiple of it. Otherwise every entry of the
    EMA state dict becomes ``ema * (1 - alpha) + model * alpha``, where alpha
    is ``cfg.TRAIN.EMA_ALPHA`` scaled by
    ``num_gpus * cfg.DATA.BATCH_SIZE / cfg.TRAIN.EPOCHS * update_period`` and
    capped at 1. During the first ``cfg.TRAIN.WARMUP_EPOCHS`` epochs alpha is
    1, i.e. the EMA weights are overwritten with the model weights.

    Args:
        cfg: Configuration object providing the fields named above.
        num_gpus: Multiplier applied to the per-device batch size.
        model: Source model (possibly wrapped).
        model_ema: EMA model (possibly wrapped) updated in place.
        cur_epoch: Current epoch, compared against the warmup length.
        cur_iter: Current iteration, compared against the update period.
    """
    period = cfg.TRAIN.EMA_UPDATE_PERIOD
    if period is None or period == 0:
        return
    if cur_iter % period != 0:
        return

    global_batch = num_gpus * cfg.DATA.BATCH_SIZE
    scale = global_batch / cfg.TRAIN.EPOCHS * period
    alpha = min(1.0, cfg.TRAIN.EMA_ALPHA * scale)
    if cur_epoch < cfg.TRAIN.WARMUP_EPOCHS:
        alpha = 1.0

    source_state = unwrap_model(model).state_dict()
    ema_state = unwrap_model(model_ema).state_dict()
    for key, ema_value in ema_state.items():
        blended = ema_value * (1.0 - alpha) + source_state[key] * alpha
        ema_value.copy_(blended)