"""Console logger utilities. Copied from https://github.com/HazyResearch/transformers/blob/master/src/utils/utils.py Copied from https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging """ import logging import math import fsspec import lightning import torch from timm.scheduler import CosineLRScheduler def fsspec_exists(filename): """Check if a file exists using fsspec.""" fs, _ = fsspec.core.url_to_fs(filename) return fs.exists(filename) def fsspec_listdir(dirname): """Listdir in manner compatible with fsspec.""" fs, _ = fsspec.core.url_to_fs(dirname) return fs.ls(dirname) def fsspec_mkdirs(dirname, exist_ok=True): """Mkdirs in manner compatible with fsspec.""" fs, _ = fsspec.core.url_to_fs(dirname) fs.makedirs(dirname, exist_ok=exist_ok) def print_nans(tensor, name): if torch.isnan(tensor).any(): print(name, tensor) class CosineDecayWarmupLRScheduler( CosineLRScheduler, torch.optim.lr_scheduler._LRScheduler): """Wrap timm.scheduler.CosineLRScheduler Enables calling scheduler.step() without passing in epoch. Supports resuming as well. Adapted from: https://github.com/HazyResearch/hyena-dna/blob/main/src/utils/optim/schedulers.py """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._last_epoch = -1 self.step(epoch=0) def step(self, epoch=None): if epoch is None: self._last_epoch += 1 else: self._last_epoch = epoch # We call either step or step_update, depending on # whether we're using the scheduler every epoch or every # step. # Otherwise, lightning will always call step (i.e., # meant for each epoch), and if we set scheduler # interval to "step", then the learning rate update will # be wrong. if self.t_in_epochs: super().step(epoch=self._last_epoch) else: super().step_update(num_updates=self._last_epoch) class LoggingContext: """Context manager for selective logging.""" def __init__(self, logger, level=None, handler=None, close=True): self.logger = logger self.level = level self.handler = handler self.close = close def __enter__(self): if self.level is not None: self.old_level = self.logger.level self.logger.setLevel(self.level) if self.handler: self.logger.addHandler(self.handler) def __exit__(self, et, ev, tb): if self.level is not None: self.logger.setLevel(self.old_level) if self.handler: self.logger.removeHandler(self.handler) if self.handler and self.close: self.handler.close() def get_logger(name=__name__, level=logging.INFO) -> logging.Logger: """Initializes multi-GPU-friendly python logger.""" logger = logging.getLogger(name) logger.setLevel(level) # this ensures all logging levels get marked with the rank zero decorator # otherwise logs would get multiplied for each GPU process in multi-GPU setup for level in ('debug', 'info', 'warning', 'error', 'exception', 'fatal', 'critical'): setattr(logger, level, lightning.pytorch.utilities.rank_zero_only( getattr(logger, level))) return logger class Sampler: def __init__(self, shape): self.shape = shape def _sampling_noise(self): pass def _hard_sample(self, logits): pass def _soft_sample(self, logits): return 0 def sample(self, logits): noise = self._sampling_noise() noise = noise[: logits.shape[0], :] logits = logits + noise.to( dtype=logits.dtype, device=logits.device) hard_sample = self._hard_sample(logits) soft_sample = self._soft_sample(logits) return soft_sample + (hard_sample - soft_sample).detach() class TopKSampler(Sampler): def __init__(self, k, shape, gamma_tau=1.0): super().__init__(shape) self.k = k self.gamma_tau = gamma_tau self.num_betas = 10 
    self.sampler = torch.distributions.gamma.Gamma(
      1 / k * torch.ones(self.num_betas, *self.shape), 1.0)

  def _sampling_noise(self):
    # Sum-of-Gamma noise: a finite-sum approximation of the Gumbel noise
    # used for top-k perturbation.
    noise = self.sampler.sample()
    beta = self.k / torch.arange(1, self.num_betas + 1, 1,
                                 dtype=torch.float32)
    beta = beta[:, None, None]
    assert beta.ndim == noise.ndim
    s = noise / beta
    s = torch.sum(s, axis=0)
    s = s - math.log(10.0)
    s = self.gamma_tau * (s / self.k)
    return s

  def _hard_sample(self, logits):
    assert logits.ndim == 2
    thresholds, _ = torch.sort(logits, dim=-1)
    thresholds = thresholds[:, -self.k][:, None]
    return (logits >= thresholds).type(logits.dtype)

  def _soft_sample(self, logits):
    soft_top_k = logits - torch.mean(logits, dim=-1, keepdim=True)
    return soft_top_k / torch.norm(soft_top_k, dim=-1, keepdim=True)


class DeterministicTopK(TopKSampler):
  """Noise-free top-k with a straight-through discretization."""

  def __init__(self, k):
    super().__init__(k, shape=(1, 1))

  def _sampling_noise(self):
    return 0

  def discreize(self, x):
    hard_sample = self._hard_sample(x)
    soft_sample = self._soft_sample(x)
    return soft_sample + (hard_sample - soft_sample).detach()


class GumbelSampler(Sampler):
  """Straight-through Gumbel-softmax sampler over the last dimension."""

  def __init__(self, shape, temperature=1.0):
    super().__init__(shape)
    self.temperature = temperature

  def _sampling_noise(self):
    return -(1e-10 - (
      torch.rand(*self.shape) + 1e-10).log()).log()

  def _hard_sample(self, logits):
    # The indexing below assumes (batch, length, vocab)-shaped logits.
    assert logits.ndim == 3
    indices = torch.argmax(logits, dim=-1)
    zeros = logits * 0
    ones = torch.ones_like(logits[:, :, :1])
    return torch.scatter(zeros, -1, indices[:, :, None], ones)

  def _soft_sample(self, logits):
    return torch.nn.functional.softmax(
      logits / self.temperature, dim=-1)


class BinarySampler(GumbelSampler):
  """Straight-through Bernoulli sampler via the Gumbel-max trick."""

  def sample(self, probs):
    # TODO(subhamsahoo): use the temperature parameter.
    pos_noise = self._sampling_noise().to(
      dtype=probs.dtype, device=probs.device)
    neg_noise = self._sampling_noise().to(
      dtype=probs.dtype, device=probs.device)
    del_noise_exp = (neg_noise - pos_noise).exp()
    hard_sample = (probs * (1 + del_noise_exp) > 1).to(probs.dtype)
    soft_sample = probs / (probs + (1 - probs) * del_noise_exp)
    return soft_sample + (hard_sample - soft_sample).detach()


class GaussianSampler:
  """Reparameterized Gaussian sampling.

  The last dimension of the input is split into a mean half and a
  pre-softplus variance half.
  """

  def __init__(self):
    self.softplus = torch.nn.Softplus()

  def sample(self, x):
    assert x.ndim == 2
    n = x.shape[-1] // 2
    mu = x[:, :n]
    sigma = self.softplus(x[:, n:]).sqrt()
    return mu + sigma * torch.randn_like(mu)
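

if __name__ == '__main__':
  # Illustrative usage sketch (an assumption, not part of the upstream
  # utilities): exercises the straight-through samplers above with toy
  # shapes to show the expected input conventions.
  batch, length, vocab = 2, 5, 16

  # GumbelSampler: logits of shape (batch, length, vocab); the noise buffer
  # must be at least as large as the batch along its first dimension.
  gumbel = GumbelSampler(shape=(batch, length, vocab), temperature=1.0)
  one_hot = gumbel.sample(torch.randn(batch, length, vocab))
  print('Gumbel sample:', one_hot.shape, one_hot.sum(-1))  # ~one-hot forward

  # TopKSampler: 2D logits of shape (batch, vocab); the forward pass yields a
  # k-hot mask, the backward pass uses the normalized soft relaxation.
  topk = TopKSampler(k=3, shape=(batch, vocab))
  mask = topk.sample(torch.randn(batch, vocab))
  print('Top-k sample:', mask.shape, mask.sum(-1))

  # GaussianSampler: splits the last dimension into mean and pre-softplus
  # variance halves and draws a reparameterized sample.
  gaussian = GaussianSampler()
  draw = gaussian.sample(torch.randn(batch, 2 * vocab))
  print('Gaussian sample:', draw.shape)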