# Source page: RepVGG / RepVGG-main / utils.py
# (author: yuxi-liu-wired; commit 0decf42, message "init")
# --------------------------------------------------------
# RepVGG: Making VGG-style ConvNets Great Again (https://openaccess.thecvf.com/content/CVPR2021/papers/Ding_RepVGG_Making_VGG-Style_ConvNets_Great_Again_CVPR_2021_paper.pdf)
# Github source: https://github.com/DingXiaoH/RepVGG
# Licensed under The MIT License [see LICENSE for details]
# The training script is based on the code of Swin Transformer (https://github.com/microsoft/Swin-Transformer)
# --------------------------------------------------------
import torch
import math
import os
class AverageMeter(object):
    """Track the latest value of a metric together with its running mean.

    Args:
        name: label printed before the numbers by ``__str__``.
        fmt: format spec, including the leading ``':'``, applied to both the
            current value and the mean (e.g. ``':6.2f'``).
    """

    def __init__(self, name, fmt=':f'):
        self.name, self.fmt = name, fmt
        self.reset()

    def reset(self):
        """Zero the current value, mean, running sum and sample count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted ``n`` times in the mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        # Produces e.g. '{name} {val:f} ({avg:f})', then fills it from the
        # instance attributes.
        template = f'{{name}} {{val{self.fmt}}} ({{avg{self.fmt}}})'
        return template.format(**vars(self))
class ProgressMeter(object):
    """Print one tab-separated progress line per call to ``display``.

    Args:
        num_batches: total number of batches; its digit count sets the width
            of the batch counter.
        meters: objects whose ``str()`` is appended to every printed line.
        prefix: text placed before the ``[batch/total]`` counter.
    """

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the counter for ``batch`` followed by every meter, tab-separated."""
        head = self.prefix + self.batch_fmtstr.format(batch)
        print('\t'.join([head, *map(str, self.meters)]))

    def _get_batch_fmtstr(self, num_batches):
        """Return a template like ``'[{:3d}/100]'`` sized to the total's digit count."""
        width = len(str(num_batches // 1))
        slot = '{:%dd}' % width
        return '[%s/%s]' % (slot, slot.format(num_batches))
def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy percentages of ``output`` scores against ``target``.

    Args:
        output: score tensor whose dim 1 indexes classes (assumed
            ``(batch, classes)`` — TODO confirm with callers).
        target: ground-truth class indices; ``target.size(0)`` is the batch size.
        topk: the k values to evaluate.

    Returns:
        A list with one 1-element float tensor per entry of ``topk``, holding
        the percentage of rows whose target is among the k highest scores.
    """
    with torch.no_grad():
        largest_k = max(topk)
        n_rows = target.size(0)
        # Indices of the largest_k highest scores per row (sorted, largest
        # first), transposed so that row i holds every sample's i-th guess.
        top_idx = output.topk(largest_k, 1, True, True)[1].t()
        hits = top_idx.eq(target.view(1, -1).expand_as(top_idx))
        scale = 100.0 / n_rows
        return [hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(scale)
                for k in topk]
def load_checkpoint(model, ckpt_path, map_location='cpu'):
    """Load weights from ``ckpt_path`` into ``model`` with strict key matching.

    The file may hold a bare state dict, or one nested under a ``'model'``
    and/or ``'state_dict'`` key. A leading ``'module.'`` on a key (typically
    added when weights are saved from a (Distributed)DataParallel wrapper) is
    stripped before loading.

    NOTE(review): another ``load_checkpoint(config, model, ...)`` is defined
    later in this module. That definition rebinds the name, so this version
    is not reachable as a module attribute after import.

    Args:
        model: module whose ``load_state_dict`` receives the weights.
        ckpt_path: path handed to ``torch.load``.
        map_location: forwarded to ``torch.load``. Defaults to ``'cpu'`` so
            GPU-saved checkpoints also load on CPU-only hosts.
            ``load_state_dict`` then copies the values onto the model's own
            devices. Pass ``None`` to keep torch's default device restore.
    """
    checkpoint = torch.load(ckpt_path, map_location=map_location)
    if 'model' in checkpoint:
        checkpoint = checkpoint['model']
    if 'state_dict' in checkpoint:
        checkpoint = checkpoint['state_dict']
    prefix = 'module.'
    ckpt = {(k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in checkpoint.items()}
    model.load_state_dict(ckpt)
class WarmupCosineAnnealingLR(torch.optim.lr_scheduler._LRScheduler):
    """Linear warmup followed by cosine annealing of every param group's LR.

    With ``t = last_epoch``:

    * ``t < warmup``: ``lr = t / warmup * base_lr``, rising linearly from 0.
    * otherwise: ``lr = eta_min + (base_lr - eta_min) * (1 + cos(pi * (t - warmup)
      / (T_cosine_max - warmup))) / 2``. This equals ``base_lr`` at
      ``t == warmup`` and ``eta_min`` at ``t == T_cosine_max``. Past that
      point the cosine keeps going, so the LR rises again.

    NOTE(review): ``T_cosine_max == warmup`` divides by zero once warmup ends.
    """

    def __init__(self, optimizer, T_cosine_max, eta_min=0, last_epoch=-1, warmup=0):
        # Set before the base constructor, which already calls get_lr().
        self.warmup = warmup
        self.T_cosine_max = T_cosine_max
        self.eta_min = eta_min
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step = self.last_epoch
        if step < self.warmup:
            ratio = step / self.warmup
            return [ratio * lr for lr in self.base_lrs]
        cos_term = 1 + math.cos(math.pi * (step - self.warmup) / (self.T_cosine_max - self.warmup))
        return [self.eta_min + (lr - self.eta_min) * cos_term / 2 for lr in self.base_lrs]
def log_msg(message, log_file):
    """Echo ``message`` to stdout and append it, newline-terminated, to ``log_file``."""
    print(message)
    with open(log_file, 'a') as handle:
        handle.write(str(message) + '\n')
# NVIDIA apex is optional: ``amp`` stays None when it cannot be imported.
amp = None
try:
    # noinspection PyUnresolvedReferences
    from apex import amp
except ImportError:
    pass
def unwrap_model(model):
    """Return ``model.module`` for a DistributedDataParallel wrapper, else ``model`` itself.

    Only DistributedDataParallel is unwrapped; other wrappers (e.g.
    DataParallel) are returned unchanged.
    """
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        return model.module
    return model
def load_checkpoint(config, model, optimizer, lr_scheduler, logger, model_ema=None):
    """Resume from the checkpoint named by ``config.MODEL.RESUME``.

    Model weights are always loaded (non-strictly). Unless
    ``config.EVAL_MODE`` is set, and only when the checkpoint carries
    ``'optimizer'``, ``'lr_scheduler'`` and ``'epoch'``, this also:

    * restores the optimizer and LR-scheduler state;
    * sets ``config.TRAIN.START_EPOCH`` to ``epoch + 1``;
    * restores the apex AMP state when both runs used AMP;
    * restores ``max_accuracy``;
    * loads the EMA weights into ``model_ema``.

    Args:
        config: config object with ``defrost()``/``freeze()`` (presumably a
            yacs CfgNode — verify against caller).
        model: module receiving ``checkpoint['model']``.
        optimizer: optimizer whose state is restored when resuming training.
        lr_scheduler: scheduler whose state is restored when resuming training.
        logger: logger used for progress messages.
        model_ema: optional EMA model receiving ``checkpoint['ema']``.

    Returns:
        The checkpoint's ``max_accuracy``, or 0.0 when it is not restored.

    Raises:
        RuntimeError: the checkpoint holds AMP state to restore but apex is
            not installed.
    """
    resume = config.MODEL.RESUME
    logger.info(f"==============> Resuming form {resume}....................")
    # Only real URLs go through torch.hub. A local path that merely starts
    # with "https" (e.g. "https_run/latest.pth") must be read from disk.
    if resume.startswith(('http://', 'https://')):
        checkpoint = torch.hub.load_state_dict_from_url(
            resume, map_location='cpu', check_hash=True)
    else:
        checkpoint = torch.load(resume, map_location='cpu')
    msg = model.load_state_dict(checkpoint['model'], strict=False)
    logger.info(msg)
    max_accuracy = 0.0
    if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        config.defrost()
        config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
        config.freeze()
        if 'amp' in checkpoint and config.AMP_OPT_LEVEL != "O0" and checkpoint['config'].AMP_OPT_LEVEL != "O0":
            if amp is None:
                # The module-level apex import fell back to None.
                raise RuntimeError("checkpoint contains apex AMP state but apex is not installed")
            amp.load_state_dict(checkpoint['amp'])
        logger.info(f"=> loaded successfully '{resume}' (epoch {checkpoint['epoch']})")
        if 'max_accuracy' in checkpoint:
            max_accuracy = checkpoint['max_accuracy']
        if model_ema is not None:
            unwrap_model(model_ema).load_state_dict(checkpoint['ema'])
            print('=================================================== EMAloaded')
    # Drop the (possibly large) checkpoint before training continues.
    del checkpoint
    torch.cuda.empty_cache()
    return max_accuracy
def load_weights(model, path):
    """Non-strictly load weights from ``path`` into the (unwrapped) ``model``.

    The file may hold a bare state dict, or one nested under a ``'model'``
    and/or ``'state_dict'`` key. When every key carries a ``'module.'``
    prefix, which typically comes from saving a DistributedDataParallel
    model, the prefix is stripped.

    Args:
        model: target module; a DistributedDataParallel wrapper is unwrapped
            first.
        path: checkpoint path handed to ``torch.load`` (loaded onto CPU).

    Returns:
        The result of ``load_state_dict``, listing missing and unexpected keys.
    """
    checkpoint = torch.load(path, map_location='cpu')
    if 'model' in checkpoint:
        checkpoint = checkpoint['model']
    if 'state_dict' in checkpoint:
        checkpoint = checkpoint['state_dict']
    # The target is unwrapped below. With strict=False, "module."-prefixed
    # keys would otherwise all be ignored silently and nothing would load.
    # Strip only when every key has the prefix, so a model that really has a
    # top-level child named "module" is left alone.
    prefix = 'module.'
    if checkpoint and all(k.startswith(prefix) for k in checkpoint):
        checkpoint = {k[len(prefix):]: v for k, v in checkpoint.items()}
    result = unwrap_model(model).load_state_dict(checkpoint, strict=False)
    print('=================== loaded from', path)
    if result.missing_keys or result.unexpected_keys:
        print('missing keys:', result.missing_keys,
              'unexpected keys:', result.unexpected_keys)
    return result
def save_latest(config, epoch, model, max_accuracy, optimizer, lr_scheduler, logger, model_ema=None):
    """Overwrite ``<config.OUTPUT>/latest.pth`` with the full training state.

    The saved dict holds ``model``, ``optimizer``, ``lr_scheduler``,
    ``max_accuracy``, ``epoch`` and ``config``. It also holds ``amp`` when
    ``config.AMP_OPT_LEVEL != "O0"`` and ``ema`` when ``model_ema`` is given.
    """
    save_state = {'model': model.state_dict(),
                  'optimizer': optimizer.state_dict(),
                  'lr_scheduler': lr_scheduler.state_dict(),
                  'max_accuracy': max_accuracy,
                  'epoch': epoch,
                  'config': config}
    if config.AMP_OPT_LEVEL != "O0":
        save_state['amp'] = amp.state_dict()
    if model_ema is not None:
        save_state['ema'] = unwrap_model(model_ema).state_dict()
    save_path = os.path.join(config.OUTPUT, 'latest.pth')
    logger.info(f"{save_path} saving......")
    # latest.pth is rewritten every time. Write a sibling temp file and swap it
    # in atomically, so an interrupted save never leaves a truncated .pth behind
    # for auto-resume to pick as the newest checkpoint. The ".tmp" suffix does
    # not end with "pth", so the resume scan ignores leftovers.
    tmp_path = save_path + '.tmp'
    torch.save(save_state, tmp_path)
    os.replace(tmp_path, save_path)
    logger.info(f"{save_path} saved !!!")
def save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, logger, is_best=False, model_ema=None):
    """Save the full training state to ``<config.OUTPUT>/ckpt_epoch_<epoch>.pth``.

    When ``is_best`` is true, the same state is first written to
    ``<config.OUTPUT>/best_ckpt.pth``. The saved dict holds ``model``,
    ``optimizer``, ``lr_scheduler``, ``max_accuracy``, ``epoch`` and
    ``config``. It also holds ``amp`` when ``config.AMP_OPT_LEVEL != "O0"``
    and ``ema`` when ``model_ema`` is given.
    """
    save_state = {'model': model.state_dict(),
                  'optimizer': optimizer.state_dict(),
                  'lr_scheduler': lr_scheduler.state_dict(),
                  'max_accuracy': max_accuracy,
                  'epoch': epoch,
                  'config': config}
    if config.AMP_OPT_LEVEL != "O0":
        save_state['amp'] = amp.state_dict()
    if model_ema is not None:
        save_state['ema'] = unwrap_model(model_ema).state_dict()

    def _atomic_save(path):
        # Write a sibling temp file, then rename it over ``path``, so an
        # interrupted save never leaves a truncated .pth file behind.
        tmp_path = path + '.tmp'
        torch.save(save_state, tmp_path)
        os.replace(tmp_path, path)

    if is_best:
        _atomic_save(os.path.join(config.OUTPUT, 'best_ckpt.pth'))
    save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth')
    logger.info(f"{save_path} saving......")
    _atomic_save(save_path)
    logger.info(f"{save_path} saved !!!")
def get_grad_norm(parameters, norm_type=2):
    """Return the total ``norm_type``-norm of all gradients as a Python float.

    Args:
        parameters: a single tensor or an iterable of tensors. Entries whose
            ``.grad`` is None are skipped.
        norm_type: order of the norm; ``float('inf')`` (or ``'inf'``) gives
            the largest absolute gradient entry.

    Returns:
        The combined norm, or 0.0 when no gradient is present.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad.data for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if norm_type == math.inf:
        # The p-norm combination below degenerates for p = inf
        # (x ** inf is 0 or inf, then ** (1 / inf) == ** 0 gives 1.0).
        # Use the definition instead: the maximum of per-tensor max norms.
        return max((g.norm(norm_type).item() for g in grads), default=0.0)
    total_norm = 0
    for g in grads:
        total_norm += g.norm(norm_type).item() ** norm_type
    return total_norm ** (1. / norm_type)
import torch.distributed as dist
def auto_resume_helper(output_dir):
    """Return the most recently modified checkpoint file in ``output_dir``, or None.

    Candidates are entries whose name ends with ``'pth'`` and does not
    contain ``'ema'``. The newest is chosen by modification time.
    """
    candidates = [entry for entry in os.listdir(output_dir)
                  if entry.endswith('pth') and 'ema' not in entry]
    print(f"All checkpoints founded in {output_dir}: {candidates}")
    if not candidates:
        return None
    newest = max((os.path.join(output_dir, entry) for entry in candidates),
                 key=os.path.getmtime)
    print(f"The latest checkpoint founded: {newest}")
    return newest
def reduce_tensor(tensor):
    """Return the element-wise mean of ``tensor`` over all ranks; the input is left untouched."""
    averaged = tensor.clone()
    dist.all_reduce(averaged, op=dist.ReduceOp.SUM)
    return averaged.div_(dist.get_world_size())
def update_model_ema(cfg, num_gpus, model, model_ema, cur_epoch, cur_iter):
    """Update exponential moving average (ema) of model weights.

    Runs only when ``cfg.TRAIN.EMA_UPDATE_PERIOD`` is non-zero and
    ``cur_iter`` is a multiple of it. Each floating-point entry of the EMA
    model's state dict becomes ``ema * (1 - alpha) + model * alpha``, where
    ``alpha = min(1, cfg.TRAIN.EMA_ALPHA * adjust)`` and
    ``adjust = num_gpus * cfg.DATA.BATCH_SIZE / cfg.TRAIN.EPOCHS * period``.
    While ``cur_epoch < cfg.TRAIN.WARMUP_EPOCHS``, ``alpha`` is 1, so the
    weights are simply copied.
    """
    update_period = cfg.TRAIN.EMA_UPDATE_PERIOD
    if update_period is None or update_period == 0 or cur_iter % update_period != 0:
        return
    # Adjust alpha to be fairly independent of other parameters
    total_batch_size = num_gpus * cfg.DATA.BATCH_SIZE
    adjust = total_batch_size / cfg.TRAIN.EPOCHS * update_period
    alpha = min(1.0, cfg.TRAIN.EMA_ALPHA * adjust)
    # During warmup simply copy over weights instead of using ema
    alpha = 1.0 if cur_epoch < cfg.TRAIN.WARMUP_EPOCHS else alpha
    # Take ema of all parameters and buffers (not just named parameters).
    # state_dict() tensors share storage with the EMA model, so the in-place
    # copy_ below updates it directly.
    params = unwrap_model(model).state_dict()
    with torch.no_grad():
        for name, param in unwrap_model(model_ema).state_dict().items():
            src = params[name]
            if param.is_floating_point():
                param.copy_(param * (1.0 - alpha) + src * alpha)
            else:
                # Integer buffers (e.g. BatchNorm's num_batches_tracked) cannot
                # be blended. The float blend would be truncated back to an
                # integer by copy_, which can freeze the value, so copy it as-is.
                param.copy_(src)