# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DeiT: https://github.com/facebookresearch/deit
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
# --------------------------------------------------------
import builtins
import datetime
import os
import time
from collections import defaultdict, deque
from pathlib import Path
import json
import subprocess
import torch
import torch.distributed as dist
from typing import List, Dict, Tuple, Optional
from torch import Tensor
class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.

    The most recent ``window_size`` values are kept in a bounded deque and
    back the ``median``, ``avg``, ``max`` and ``value`` properties, while
    ``total`` and ``count`` accumulate over the whole series and back
    ``global_avg``.
    """

    def __init__(self, window_size=20, fmt=None):
        """Create an empty meter.

        Args:
            window_size: maximum number of recent values kept in the window.
            fmt: ``str.format`` template used by ``__str__``; it may refer to
                ``median``, ``avg``, ``global_avg``, ``max`` and ``value``.
        """
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        """Record ``value``, weighting it by ``n`` in the global statistics."""
        # The window stores the value once; only the global sums use ``n``.
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """Sum ``count`` and ``total`` across all distributed processes.

        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        """Median of the values currently in the window."""
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        """Mean of the values currently in the window."""
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        """Weighted mean over every value recorded so far."""
        return self.total / self.count

    @property
    def max(self):
        """Largest value currently in the window."""
        return max(self.deque)

    @property
    def value(self):
        """Most recently recorded value."""
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)
class MetricLogger(object):
    """Collect named meters and print periodic progress lines for a loop."""

    def __init__(self, delimiter="\t"):
        """Create a logger whose fields are joined with ``delimiter``."""
        # Meters are created on first use of a new name.
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Record one value per keyword; ``None`` values are skipped.

        Tensor values are converted with ``.item()``; every other value must
        be a ``float`` or an ``int``.
        """
        for k, v in kwargs.items():
            if v is None:
                continue
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Look ``meters`` up through ``__dict__`` so that an instance whose
        # attributes are not set yet raises AttributeError instead of
        # re-entering ``__getattr__`` without end.
        meters = self.__dict__.get("meters", {})
        if attr in meters:
            return meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        """Synchronize every registered meter across processes."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        """Register ``meter`` under ``name``, replacing any existing one."""
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield the items of ``iterable`` while printing progress.

        A line is printed every ``print_freq`` iterations and on the last
        iteration, followed by a total-time summary once the loop ends.

        Args:
            iterable: sized iterable (``len()`` is required).
            print_freq: number of iterations between progress lines.
            header: optional prefix for every printed line.
        """
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        # Pad the iteration counter to the width of the total count.
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        if torch.cuda.is_available():
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            # Time spent waiting for the next item.
            data_time.update(time.time() - end)
            yield obj
            # Full iteration time, including the caller's work on ``obj``.
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        # Guard the per-iteration average against an empty iterable.
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / max(len(iterable), 1)))
def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process

    It replaces ``builtins.print`` with a wrapper that prints only when
    ``is_master`` is true, when the caller passes ``force=True``, or when the
    world size is greater than 8. Printed messages are prefixed with the
    current wall-clock time.
    """
    builtin_print = builtins.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        force = force or (get_world_size() > 8)
        if is_master or force:
            now = datetime.datetime.now().time()
            # Send the prefix to the same stream as the message so that
            # ``print(..., file=f)`` does not leak the timestamp to stdout.
            builtin_print('[{}] '.format(now), end='', file=kwargs.get('file'))  # print with time stamp
            builtin_print(*args, **kwargs)

    builtins.print = print
def is_dist_avail_and_initialized():
    """Return True only if torch.distributed is available and initialized."""
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True
def get_world_size():
    """Return the number of distributed processes, or 1 when not distributed."""
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()
def get_rank():
    """Return this process's distributed rank, or 0 when not distributed."""
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()
def is_main_process():
    """Return True when the current process has rank 0."""
    return get_rank() == 0
def save_on_master(*args, **kwargs):
    """Call ``torch.save`` with the given arguments on the main process only."""
    if is_main_process():
        torch.save(*args, **kwargs)
def init_distributed_mode(args):
    """Initialize torch.distributed from the environment and fill in ``args``.

    Two launch styles are recognized:

    * ``RANK`` and ``WORLD_SIZE`` set in the environment: rank, world size and
      local GPU index are read from ``RANK``, ``WORLD_SIZE`` and
      ``LOCAL_RANK``.
    * ``SLURM_PROCID`` set in the environment: the values are derived from the
      SLURM variables and the matching ``MASTER_ADDR``/``MASTER_PORT``/
      ``WORLD_SIZE``/``RANK``/``LOCAL_RANK``/``LOCAL_SIZE`` variables are
      exported.

    Otherwise ``args.distributed`` is set to ``False`` and nothing else is
    initialized.

    On the distributed paths this sets ``args.rank``, ``args.world_size``,
    ``args.gpu``, ``args.dist_url``, ``args.distributed`` and
    ``args.dist_backend``, selects the CUDA device, initializes the NCCL
    process group and silences ``print`` on non-master ranks.
    """
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
        args.dist_url = 'env://'
        os.environ['LOCAL_SIZE'] = str(torch.cuda.device_count())
    elif 'SLURM_PROCID' in os.environ:
        proc_id = int(os.environ['SLURM_PROCID'])
        ntasks = int(os.environ['SLURM_NTASKS'])
        node_list = os.environ['SLURM_NODELIST']
        num_gpus = torch.cuda.device_count()
        # The first host of the allocation acts as the rendezvous address.
        # NOTE: node_list is interpolated into a shell command unquoted.
        addr = subprocess.getoutput(
            'scontrol show hostname {} | head -n1'.format(node_list))
        os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', '29200')
        os.environ['MASTER_ADDR'] = addr
        os.environ['WORLD_SIZE'] = str(ntasks)
        os.environ['RANK'] = str(proc_id)
        os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
        os.environ['LOCAL_SIZE'] = str(num_gpus)
        args.dist_url = 'env://'
        args.world_size = ntasks
        args.rank = proc_id
        args.gpu = proc_id % num_gpus
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)
def clip_grad_norm_(
        parameters, max_norm: float, norm_type: float = 2.0,
        error_if_nonfinite: bool = False, foreach: Optional[bool] = None) -> torch.Tensor:
    r"""Clips gradient norm of an iterable of parameters.

    The norm is computed over all gradients together, as if they were
    concatenated into a single vector. Gradients are modified in-place.

    Args:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        max_norm (float): max norm of the gradients
        norm_type (float): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.
        error_if_nonfinite (bool): if True, an error is thrown if the total
            norm of the gradients from :attr:`parameters` is ``nan``,
            ``inf``, or ``-inf``. Default: False
        foreach (bool): if ``None`` or ``True``, scale the gradients with a
            single ``torch._foreach_mul_`` call; if ``False``, scale them one
            tensor at a time.
            Default: ``None``

    Returns:
        Total norm of the parameter gradients (viewed as a single vector).

    Raises:
        RuntimeError: if ``error_if_nonfinite`` is True and the total norm is
            not finite. The gradients are left untouched in that case.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad for p in parameters if p.grad is not None]
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    if len(grads) == 0:
        return torch.tensor(0.)
    first_device = grads[0].device

    # Per-tensor norms are gathered on one device before being combined.
    if norm_type == float('inf'):
        norms = [g.detach().abs().max().to(first_device) for g in grads]
        total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
    else:
        norms = [torch.norm(g.detach(), norm_type).to(first_device) for g in grads]
        total_norm = torch.norm(torch.stack(norms), norm_type)

    if error_if_nonfinite and not torch.isfinite(total_norm):
        raise RuntimeError(
            f'The total norm of order {norm_type} for gradients from '
            '`parameters` is non-finite, so it cannot be clipped. To disable '
            'this error and scale the gradients by the non-finite norm anyway, '
            'set `error_if_nonfinite=False`')

    clip_coef = max_norm / (total_norm + 1e-6)
    # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
    # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
    # when the gradients do not reside in CPU memory.
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0).to(first_device)
    detached = [g.detach() for g in grads]
    if foreach is None or foreach:
        torch._foreach_mul_(detached, clip_coef_clamped)  # type: ignore[call-overload]
    else:
        for g in detached:
            g.mul_(clip_coef_clamped)

    return total_norm
class NativeScalerWithGradNormCount:
    """Wrap ``torch.cuda.amp.GradScaler`` and report the gradient norm.

    Calling the instance runs the scaled backward pass and, when
    ``update_grad`` is true, unscales the gradients, optionally clips them,
    steps the optimizer and updates the scaler.
    """

    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
        """Backpropagate ``loss`` and optionally step ``optimizer``.

        Args:
            loss: loss tensor to scale and backpropagate.
            optimizer: optimizer whose gradients are unscaled and stepped.
            clip_grad: maximum gradient norm; ``None`` disables clipping.
            parameters: parameters whose gradient norm is measured (and
                clipped when ``clip_grad`` is given).
            create_graph: forwarded to ``backward``.
            update_grad: when false, only the backward pass runs (for
                example while accumulating gradients).

        Returns:
            The gradient norm, or ``None`` when ``update_grad`` is false.
        """
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if update_grad:
            if clip_grad is not None:
                assert parameters is not None
                self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
                norm = clip_grad_norm_(parameters, clip_grad)
            else:
                self._scaler.unscale_(optimizer)
                norm = get_grad_norm_(parameters)
            self._scaler.step(optimizer)
            self._scaler.update()
        else:
            norm = None
        return norm

    def state_dict(self):
        """Return the wrapped scaler's state dict."""
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        """Load ``state_dict`` into the wrapped scaler."""
        self._scaler.load_state_dict(state_dict)
def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
    """Return the total gradient norm of ``parameters`` without modifying them.

    Args:
        parameters: a Tensor or an iterable of Tensors; entries whose
            ``grad`` is ``None`` are ignored.
        norm_type: order of the p-norm; ``float('inf')`` gives the largest
            absolute gradient entry.

    Returns:
        The norm as a tensor on the device of the first gradient, or
        ``torch.tensor(0.)`` when no parameter has a gradient.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if len(parameters) == 0:
        return torch.tensor(0.)
    device = parameters[0].grad.device
    # Compare against float('inf') directly: no ``inf`` name is imported here.
    if norm_type == float('inf'):
        total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
    else:
        total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
    return total_norm
def save_model(args, epoch, model, model_without_ddp, optimizer):
    """Save a training checkpoint to ``<args.output_dir>/checkpoint.pth``.

    The checkpoint holds the state dicts of ``model_without_ddp`` and
    ``optimizer``, the epoch number and ``args``. The file is written through
    ``save_on_master``. ``model`` is accepted but not used.
    """
    output_dir = Path(args.output_dir)
    # A single fixed file name is used, so each save overwrites the previous
    # one. Per-epoch naming would be:
    # checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch)]
    checkpoint_paths = [output_dir / 'checkpoint.pth']
    for checkpoint_path in checkpoint_paths:
        to_save = {
            'model': model_without_ddp.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch,
            'args': args,
        }
        save_on_master(to_save, checkpoint_path)
def load_model(args, model_without_ddp, optimizer):
    """Restore model (and optionally optimizer) state from ``args.resume``.

    Nothing happens when ``args.resume`` is empty. A value starting with
    ``https`` is fetched with ``torch.hub.load_state_dict_from_url``; any
    other value is loaded with ``torch.load``. The checkpoint's ``'model'``
    entry is loaded into ``model_without_ddp``.

    When the checkpoint has both ``'optimizer'`` and ``'epoch'`` and
    ``args.eval`` is not set, the optimizer state is restored too and
    ``args.start_epoch`` is set to the saved epoch plus one.
    """
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        print("Resume checkpoint %s" % args.resume)
        if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
            optimizer.load_state_dict(checkpoint['optimizer'])
            args.start_epoch = checkpoint['epoch'] + 1
            print("With optim & sched!")
def auto_load_model(args, model, model_without_ddp, optimizer):
    """Resume from ``args.resume``, discovering a checkpoint when asked to.

    When ``args.auto_resume`` is set and ``args.resume`` is empty,
    ``args.output_dir`` is searched for ``checkpoint-<N>.pth`` files and the
    one with the largest ``N`` is selected. If there is none, a plain
    ``checkpoint.pth`` in the same directory is used when it exists.

    The selected checkpoint's ``'model'`` entry is loaded into
    ``model_without_ddp``. When the checkpoint has both ``'optimizer'`` and
    ``'epoch'``, the optimizer state is restored and ``args.start_epoch`` is
    set to the saved epoch plus one. ``model`` is accepted but not used.
    """
    output_dir = Path(args.output_dir)
    # torch.amp
    if args.auto_resume and len(args.resume) == 0:
        import glob
        all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth'))
        latest_ckpt = -1
        for ckpt in all_checkpoints:
            # "checkpoint-<N>.pth" -> "<N>"
            t = ckpt.split('-')[-1].split('.')[0]
            if t.isdigit():
                latest_ckpt = max(int(t), latest_ckpt)
        if latest_ckpt >= 0:
            args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt)
        else:
            # No numbered checkpoint: fall back to the un-numbered file name.
            fallback = os.path.join(output_dir, 'checkpoint.pth')
            if os.path.isfile(fallback):
                args.resume = fallback
        print("Auto resume checkpoint: %s" % args.resume)

    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        print("Resume checkpoint %s" % args.resume)
        if 'optimizer' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            args.start_epoch = checkpoint['epoch'] + 1
            print("With optim & sched!")
def all_reduce_mean(x):
    """Return the mean of scalar ``x`` over all distributed processes.

    With a world size of 1, ``x`` is returned unchanged. Otherwise ``x`` is
    moved to the GPU, summed across processes, divided by the world size and
    returned as a Python number.
    """
    world_size = get_world_size()
    if world_size > 1:
        x_reduce = torch.tensor(x).cuda()
        # In-place true division is not defined for integer tensors, so
        # promote integer/bool inputs before reducing.
        if not (x_reduce.is_floating_point() or x_reduce.is_complex()):
            x_reduce = x_reduce.float()
        dist.all_reduce(x_reduce)
        x_reduce /= world_size
        return x_reduce.item()
    else:
        return x
def create_ds_config(args):
    """Write a DeepSpeed JSON config to ``<args.output_dir>/deepspeed_config.json``.

    The path is stored in ``args.deepspeed_config``. The config is built from
    ``args.batch_size``, ``args.accum_iter``, ``args.lr``,
    ``args.weight_decay``, ``args.opt_eps``, ``args.clip_grad`` and
    ``args.zero_stage``.

    Raises:
        NotImplementedError: if ``args.zero_stage`` is greater than 1. The
            config is validated before the file is opened, so no empty file
            is left behind in that case.
    """
    args.deepspeed_config = os.path.join(args.output_dir, "deepspeed_config.json")

    # NOTE(review): the two Adam beta entries are not visible in this file's
    # current text. ``args.opt_betas`` is used when provided, otherwise Adam's
    # usual (0.9, 0.999) -- confirm against the intended configuration.
    opt_betas = getattr(args, "opt_betas", None) or (0.9, 0.999)

    ds_config = {
        "train_batch_size": args.batch_size * args.accum_iter * get_world_size(),
        "train_micro_batch_size_per_gpu": args.batch_size,
        "steps_per_print": 1000,
        "optimizer": {
            "type": "Adam",
            "adam_w_mode": True,
            "params": {
                "lr": args.lr,
                "weight_decay": args.weight_decay,
                "bias_correction": True,
                "betas": [
                    opt_betas[0],
                    opt_betas[1]
                ],
                "eps": args.opt_eps
            }
        },
        "fp16": {
            "enabled": True,
            "loss_scale": 0,
            "initial_scale_power": 16,
            "loss_scale_window": 1000,
            "hysteresis": 2,
            "min_loss_scale": 1
        },
        # "bf16": {
        #     "enabled": True
        # },
        "amp": {
            "enabled": False,
            "opt_level": "O2"
        },
        "flops_profiler": {
            "enabled": True,
            "profile_step": -1,
            "module_depth": -1,
            "top_modules": 1,
            "detailed": True,
        },
    }

    if args.clip_grad is not None:
        ds_config.update({'gradient_clipping': args.clip_grad})

    if args.zero_stage == 1:
        ds_config.update({"zero_optimization": {"stage": args.zero_stage, "reduce_bucket_size": 5e8}})
    elif args.zero_stage > 1:
        raise NotImplementedError()

    with open(args.deepspeed_config, mode="w") as writer:
        writer.write(json.dumps(ds_config, indent=2))
def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None):
    """Split a model's trainable parameters into optimizer parameter groups.

    Parameters that are 1-D, whose name ends with ``.bias``, or whose name is
    in ``skip_list`` go into a ``no_decay`` group with weight decay 0; all
    others go into a ``decay`` group with ``weight_decay``. When
    ``get_num_layer`` is given, groups are further split per layer id.

    Args:
        model: module providing ``named_parameters()``.
        weight_decay: weight decay applied to the ``decay`` groups.
        skip_list: parameter names that must not be decayed.
        get_num_layer: optional callable mapping a parameter name to an
            integer layer id.
        get_layer_scale: optional callable mapping a layer id (``None`` when
            ``get_num_layer`` is not given) to a learning-rate scale;
            the scale defaults to 1.

    Returns:
        A list of dicts with keys ``weight_decay``, ``params`` and
        ``lr_scale``. A summary with parameter names is printed as JSON.
    """
    # Same grouping twice: names for the printed summary, tensors for the
    # returned groups.
    parameter_group_names = {}
    parameter_group_vars = {}

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
            group_name = "no_decay"
            this_weight_decay = 0.
        else:
            group_name = "decay"
            this_weight_decay = weight_decay
        if get_num_layer is not None:
            layer_id = get_num_layer(name)
            group_name = "layer_%d_%s" % (layer_id, group_name)
        else:
            layer_id = None

        if group_name not in parameter_group_names:
            if get_layer_scale is not None:
                scale = get_layer_scale(layer_id)
            else:
                scale = 1.

            parameter_group_names[group_name] = {
                "weight_decay": this_weight_decay,
                "params": [],
                "lr_scale": scale
            }
            parameter_group_vars[group_name] = {
                "weight_decay": this_weight_decay,
                "params": [],
                "lr_scale": scale
            }

        parameter_group_vars[group_name]["params"].append(param)
        parameter_group_names[group_name]["params"].append(name)
    print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
    return list(parameter_group_vars.values())