# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

"""Facilities for reporting and collecting training statistics across
multiple processes and devices. The interface is designed to minimize
synchronization overhead as well as the amount of boilerplate in user
code."""

import re
import numpy as np
import torch
import dnnlib

from . import misc

#----------------------------------------------------------------------------

_num_moments    = 3             # [num_scalars, sum_of_scalars, sum_of_squares]
_reduce_dtype   = torch.float32 # Data type to use for initial per-tensor reduction.
_counter_dtype  = torch.float64 # Data type to use for the internal counters.
_rank           = 0             # Rank of the current process.
_sync_device    = None          # Device to use for multiprocess communication. None = single-process.
_sync_called    = False         # Has _sync() been called yet?
_counters       = dict()        # Running counters on each device, updated by report(): name => device => torch.Tensor
_cumulative     = dict()        # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor

#----------------------------------------------------------------------------

def init_multiprocessing(rank, sync_device):
    r"""Initializes `torch_utils.training_stats` for collecting statistics
    across multiple processes.

    This function must be called after
    `torch.distributed.init_process_group()` and before `Collector.update()`.
    The call is not necessary if multi-process collection is not needed.

    Args:
        rank:           Rank of the current process.
        sync_device:    PyTorch device to use for inter-process
                        communication, or None to disable multi-process
                        collection. Typically `torch.device('cuda', rank)`.

    Raises:
        AssertionError: If a synchronization has already been performed,
                        i.e., `_sync()` has run with a non-empty set of names.
    """
    global _rank, _sync_device
    assert not _sync_called
    _rank = rank
    _sync_device = sync_device

#----------------------------------------------------------------------------

@misc.profiled_function
def report(name, value):
    r"""Broadcasts the given set of scalars to all interested instances of
    `Collector`, across device and process boundaries. NaNs and Infs are ignored.

    This function is expected to be extremely cheap and can be safely
    called from anywhere in the training loop, loss function, or inside a
    `torch.nn.Module`.

    Warning: The current implementation expects the set of unique names to
    be consistent across processes. Please make sure that `report()` is
    called at least once for each unique name by each process, and in the
    same order. If a given process has no scalars to broadcast, it can do
    `report(name, [])` (empty list).

    Args:
        name:   Arbitrary string specifying the name of the statistic.
                Averages are accumulated separately for each unique name.
        value:  Arbitrary set of scalars. Can be a list, tuple,
                NumPy array, PyTorch tensor, or Python scalar.

    Returns:
        The same `value` that was passed in.
    """
    # Register the name even when there is nothing to accumulate, so that
    # `report(name, [])` still makes the name visible to `Collector.names()`.
    if name not in _counters:
        _counters[name] = dict()

    elems = torch.as_tensor(value)
    if elems.numel() == 0:
        return value

    elems = elems.detach().flatten().to(_reduce_dtype)
    square = elems.square()
    # The mask is derived from the squares, so besides NaN/Inf inputs it also
    # drops elements whose square is not finite in `_reduce_dtype`.
    finite = square.isfinite()
    moments = torch.stack([
        finite.sum(dtype=_reduce_dtype),
        torch.where(finite, elems, 0).sum(),
        torch.where(finite, square, 0).sum(),
    ])
    assert moments.ndim == 1 and moments.shape[0] == _num_moments
    moments = moments.to(_counter_dtype)

    # Accumulate into a counter that lives on the same device as the moments.
    device = moments.device
    if device not in _counters[name]:
        _counters[name][device] = torch.zeros_like(moments)
    _counters[name][device].add_(moments)
    return value

#----------------------------------------------------------------------------

def report0(name, value):
    r"""Broadcasts the given set of scalars by the first process (`rank = 0`),
    but ignores any scalars provided by the other processes.
    See `report()` for further details.

    Returns:
        The same `value` that was passed in, on every rank.
    """
    # Non-zero ranks report an empty list so that the name is still registered.
    report(name, value if _rank == 0 else [])
    return value

#----------------------------------------------------------------------------

class Collector:
    r"""Collects the scalars broadcasted by `report()` and `report0()` and
    computes their long-term averages (mean and standard deviation) over
    user-defined periods of time.

    The averages are first collected into internal counters that are not
    directly visible to the user. They are then copied to the user-visible
    state as a result of calling `update()` and can then be queried using
    `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the
    internal counters for the next round, so that the user-visible state
    effectively reflects averages collected between the last two calls to
    `update()`.

    Args:
        regex:          Regular expression defining which statistics to
                        collect. The default is to collect everything.
        keep_previous:  Whether to retain the previous averages if no
                        scalars were collected on a given round
                        (default: False).
    """
    def __init__(self, regex='.*', keep_previous=False):
        self._regex = re.compile(regex)
        self._keep_previous = keep_previous
        self._cumulative = dict()   # name => cumulative moments seen at the last update()
        self._moments = dict()      # name => moments accumulated between the last two updates

        # Record the current cumulative values as the baseline and discard
        # the resulting deltas, so that scalars reported before construction
        # do not show up in this collector's statistics.
        self.update()
        self._moments.clear()

    def names(self):
        r"""Returns the names of all statistics broadcasted so far that
        match the regular expression specified at construction time.
        """
        return [name for name in _counters if self._regex.fullmatch(name)]

    def update(self):
        r"""Copies current values of the internal counters to the
        user-visible state and resets them for the next round.

        If `keep_previous=True` was specified at construction time, the
        operation is skipped for statistics that have received no scalars
        since the last update, retaining their previous averages.

        This method performs a number of GPU-to-CPU transfers and, when
        multi-process collection is enabled, one
        `torch.distributed.broadcast()` and one
        `torch.distributed.all_reduce()`. It is intended to be called
        periodically in the main training loop, typically once every
        N training steps.
        """
        if not self._keep_previous:
            self._moments.clear()
        for name, cumulative in _sync(self.names()):
            if name not in self._cumulative:
                self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
            delta = cumulative - self._cumulative[name]
            self._cumulative[name].copy_(cumulative)
            # Only overwrite the visible moments if something was collected.
            if float(delta[0]) != 0:
                self._moments[name] = delta

    def _get_delta(self, name):
        r"""Returns the raw moments that were accumulated for the given
        statistic between the last two calls to `update()`, or zero if
        no scalars were collected.
        """
        assert self._regex.fullmatch(name)
        if name not in self._moments:
            self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
        return self._moments[name]

    def num(self, name):
        r"""Returns the number of scalars that were accumulated for the given
        statistic between the last two calls to `update()`, or zero if
        no scalars were collected.
        """
        delta = self._get_delta(name)
        return int(delta[0])

    def mean(self, name):
        r"""Returns the mean of the scalars that were accumulated for the
        given statistic between the last two calls to `update()`, or NaN if
        no scalars were collected.
        """
        delta = self._get_delta(name)
        if int(delta[0]) == 0:
            return float('nan')
        return float(delta[1] / delta[0])

    def std(self, name):
        r"""Returns the standard deviation of the scalars that were
        accumulated for the given statistic between the last two calls to
        `update()`, or NaN if no scalars were collected.

        NaN is also returned if the accumulated sum is not finite, and zero
        is returned if exactly one scalar was collected. The result is always
        a built-in `float`.
        """
        delta = self._get_delta(name)
        if int(delta[0]) == 0 or not np.isfinite(float(delta[1])):
            return float('nan')
        if int(delta[0]) == 1:
            return float(0)
        mean = float(delta[1] / delta[0])
        raw_var = float(delta[2] / delta[0])
        # E[x^2] - E[x]^2 can come out slightly negative due to rounding, so
        # clamp at zero. Convert to a built-in float so that the return type
        # matches the other branches and `mean()`.
        return float(np.sqrt(max(raw_var - np.square(mean), 0)))

    def as_dict(self):
        r"""Returns the averages accumulated between the last two calls to
        `update()` as a `dnnlib.EasyDict`. The contents are as follows:

            dnnlib.EasyDict(
                NAME = dnnlib.EasyDict(num=INT, mean=FLOAT, std=FLOAT),
                ...
            )
        """
        stats = dnnlib.EasyDict()
        for name in self.names():
            stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name))
        return stats

    def __getitem__(self, name):
        r"""Convenience getter.
        `collector[name]` is a synonym for `collector.mean(name)`.
        """
        return self.mean(name)

#----------------------------------------------------------------------------

def _sync(names):
    r"""Synchronize the global cumulative counters across devices
    and processes. Called internally by `Collector.update()`.

    Args:
        names:  List of statistic names to synchronize. Every name must
                already be present in `_counters`.

    Returns:
        List of `(name, cumulative)` pairs in the same order as `names`,
        where `cumulative` is the CPU tensor stored in `_cumulative`
        (not a copy). Returns an empty list if `names` is empty.

    Raises:
        ValueError: In multi-process mode, if the names on this rank do not
                    match the names on rank 0.
    """
    if len(names) == 0:
        return []
    global _sync_called
    _sync_called = True

    # Check that all ranks have the same set of names.
    if _sync_device is not None:
        # Hash code points rather than the strings themselves, so the value
        # is built from integers only.
        value = hash(tuple(tuple(ord(char) for char in name) for name in names))
        other = torch.as_tensor(value, dtype=torch.int64, device=_sync_device)
        torch.distributed.broadcast(tensor=other, src=0)
        if value != int(other.cpu()):
            raise ValueError('Training statistics are inconsistent between ranks')

    # Collect deltas within current rank.
    deltas = []
    device = _sync_device if _sync_device is not None else torch.device('cpu')
    for name in names:
        delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device)
        for counter in _counters[name].values():
            delta.add_(counter.to(device))
            counter.zero_() # Reset in place for the next round.
        deltas.append(delta)
    deltas = torch.stack(deltas)

    # Sum deltas across ranks.
    if _sync_device is not None:
        torch.distributed.all_reduce(deltas)

    # Update cumulative values.
    deltas = deltas.cpu()
    for idx, name in enumerate(names):
        if name not in _cumulative:
            _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
        _cumulative[name].add_(deltas[idx])

    # Return name-value pairs.
    return [(name, _cumulative[name]) for name in names]

#----------------------------------------------------------------------------
# Convenience.

default_collector = Collector()

#----------------------------------------------------------------------------