# MeMDLM/utils.py
"""Console logger utilities.
Copied from https://github.com/HazyResearch/transformers/blob/master/src/utils/utils.py
Copied from https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging
"""
import logging
import math
import fsspec
import lightning
import torch
from timm.scheduler import CosineLRScheduler


def fsspec_exists(filename):
    """Check if a file exists using fsspec."""
    fs, _ = fsspec.core.url_to_fs(filename)
    return fs.exists(filename)


def fsspec_listdir(dirname):
    """Listdir in a manner compatible with fsspec."""
    fs, _ = fsspec.core.url_to_fs(dirname)
    return fs.ls(dirname)


def fsspec_mkdirs(dirname, exist_ok=True):
    """Mkdirs in a manner compatible with fsspec."""
    fs, _ = fsspec.core.url_to_fs(dirname)
    fs.makedirs(dirname, exist_ok=exist_ok)
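

# Illustrative sketch (not part of the original module): how the fsspec
# helpers above might be combined. The directory name is a placeholder
# assumption; any fsspec-compatible URL (e.g. a local path) would work.
def _example_fsspec_usage(output_dir='./outputs'):
    if not fsspec_exists(output_dir):
        fsspec_mkdirs(output_dir, exist_ok=True)
    return fsspec_listdir(output_dir)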


def print_nans(tensor, name):
    """Print the tensor (tagged with `name`) if it contains any NaNs."""
    if torch.isnan(tensor).any():
        print(name, tensor)


class CosineDecayWarmupLRScheduler(
        CosineLRScheduler,
        torch.optim.lr_scheduler._LRScheduler):
    """Wrap timm.scheduler.CosineLRScheduler.

    Enables calling scheduler.step() without passing in epoch.
    Supports resuming as well.
    Adapted from:
      https://github.com/HazyResearch/hyena-dna/blob/main/src/utils/optim/schedulers.py
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._last_epoch = -1
        self.step(epoch=0)

    def step(self, epoch=None):
        if epoch is None:
            self._last_epoch += 1
        else:
            self._last_epoch = epoch
        # We call either step or step_update, depending on whether we're
        # using the scheduler every epoch or every step. Otherwise,
        # Lightning will always call step (i.e., meant for each epoch),
        # and if we set the scheduler interval to "step", the learning
        # rate update would be wrong.
        if self.t_in_epochs:
            super().step(epoch=self._last_epoch)
        else:
            super().step_update(num_updates=self._last_epoch)
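

# Minimal sketch (an assumption, not taken from the original training code):
# one way to wire CosineDecayWarmupLRScheduler into a LightningModule's
# `configure_optimizers`. All hyperparameter values below are placeholders.
def _example_configure_optimizers(parameters):
    optimizer = torch.optim.AdamW(parameters, lr=3e-4)
    scheduler = CosineDecayWarmupLRScheduler(
        optimizer,
        t_initial=10_000,     # total number of optimizer steps (assumed)
        warmup_t=1_000,       # linear warmup steps (assumed)
        warmup_lr_init=1e-6,
        lr_min=1e-6,
        t_in_epochs=False)    # step-wise updates, matching interval='step'
    return {
        'optimizer': optimizer,
        'lr_scheduler': {'scheduler': scheduler, 'interval': 'step'}}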


class LoggingContext:
    """Context manager for selective logging."""

    def __init__(self, logger, level=None, handler=None, close=True):
        self.logger = logger
        self.level = level
        self.handler = handler
        self.close = close

    def __enter__(self):
        if self.level is not None:
            self.old_level = self.logger.level
            self.logger.setLevel(self.level)
        if self.handler:
            self.logger.addHandler(self.handler)

    def __exit__(self, et, ev, tb):
        if self.level is not None:
            self.logger.setLevel(self.old_level)
        if self.handler:
            self.logger.removeHandler(self.handler)
        if self.handler and self.close:
            self.handler.close()


def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:
    """Initializes a multi-GPU-friendly python logger."""
    logger = logging.getLogger(name)
    logger.setLevel(level)
    # This ensures all logging levels get marked with the rank-zero
    # decorator; otherwise logs would get multiplied for each GPU
    # process in a multi-GPU setup.
    for level_name in ('debug', 'info', 'warning', 'error',
                       'exception', 'fatal', 'critical'):
        setattr(logger,
                level_name,
                lightning.pytorch.utilities.rank_zero_only(
                    getattr(logger, level_name)))
    return logger
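

# Illustrative sketch (not part of the original module): using get_logger
# together with LoggingContext to temporarily raise verbosity around a
# noisy block of code.
def _example_logging_usage():
    logger = get_logger(__name__)
    logger.info('Normal-verbosity message (rank zero only).')
    with LoggingContext(logger, level=logging.DEBUG,
                        handler=logging.StreamHandler()):
        logger.debug('Emitted only inside this context.')
    # On exit the original level is restored and the handler is removed.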


class Sampler:
    """Base class for straight-through samplers."""

    def __init__(self, shape):
        self.shape = shape

    def _sampling_noise(self):
        # Overridden by subclasses; returns the noise added to the logits.
        pass

    def _hard_sample(self, logits):
        # Overridden by subclasses; returns the discrete (hard) sample.
        pass

    def _soft_sample(self, logits):
        # Overridden by subclasses; returns the relaxed (soft) sample.
        return 0

    def sample(self, logits):
        noise = self._sampling_noise()
        noise = noise[: logits.shape[0], :]
        logits = logits + noise.to(
            dtype=logits.dtype, device=logits.device)
        hard_sample = self._hard_sample(logits)
        soft_sample = self._soft_sample(logits)
        # Straight-through estimator: the forward pass returns the hard
        # sample, while gradients flow through the soft sample.
        return soft_sample + (hard_sample - soft_sample).detach()


class TopKSampler(Sampler):
    """Differentiable top-k selection; noise is built from sums of Gammas."""

    def __init__(self, k, shape, gamma_tau=1.0):
        super().__init__(shape)
        self.k = k
        self.gamma_tau = gamma_tau
        self.num_betas = 10
        self.sampler = torch.distributions.gamma.Gamma(
            1 / k * torch.ones(self.num_betas, *self.shape), 1.0)

    def _sampling_noise(self):
        noise = self.sampler.sample()
        beta = self.k / torch.arange(1, self.num_betas + 1, 1,
                                     dtype=torch.float32)
        beta = beta[:, None, None]
        assert beta.ndim == noise.ndim
        s = noise / beta
        s = torch.sum(s, axis=0)
        s = s - math.log(10.0)
        s = self.gamma_tau * (s / self.k)
        return s

    def _hard_sample(self, logits):
        assert logits.ndim == 2
        thresholds, _ = torch.sort(logits, dim=-1)
        thresholds = thresholds[:, -self.k][:, None]
        return (logits >= thresholds).type(logits.dtype)

    def _soft_sample(self, logits):
        soft_top_k = logits - torch.mean(logits, dim=-1, keepdim=True)
        return soft_top_k / torch.norm(soft_top_k, dim=-1, keepdim=True)
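

# Illustrative sketch (an assumption, not from the original training code):
# drawing a differentiable top-k mask over a batch of logits. Shapes and k
# are placeholders.
def _example_top_k_mask(batch_size=4, vocab_size=10, k=2):
    sampler = TopKSampler(k=k, shape=(batch_size, vocab_size))
    logits = torch.randn(batch_size, vocab_size, requires_grad=True)
    mask = sampler.sample(logits)
    # `mask` is a {0, 1} top-k indicator in the forward pass, while gradients
    # flow back to `logits` through the soft relaxation.
    return mask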


class DeterministicTopK(TopKSampler):
    """Top-k selection without sampling noise."""

    def __init__(self, k):
        super().__init__(k, shape=(1, 1))

    def _sampling_noise(self):
        return 0

    def discreize(self, x):
        hard_sample = self._hard_sample(x)
        soft_sample = self._soft_sample(x)
        return soft_sample + (hard_sample - soft_sample).detach()


class GumbelSampler(Sampler):
    """Categorical sampling with Gumbel noise and a softmax relaxation."""

    def __init__(self, shape, temperature=1.0):
        super().__init__(shape)
        self.temperature = temperature

    def _sampling_noise(self):
        # Standard Gumbel(0, 1) noise: -log(-log(U)), clamped for stability.
        return -(1e-10 - (
            torch.rand(*self.shape) + 1e-10).log()).log()

    def _hard_sample(self, logits):
        # The indexing below assumes (batch, length, vocab) logits.
        assert logits.ndim == 3
        indices = torch.argmax(logits, dim=-1)
        zeros = logits * 0
        ones = torch.ones_like(logits[:, :, :1])
        return torch.scatter(zeros, -1, indices[:, :, None], ones)

    def _soft_sample(self, logits):
        return torch.nn.functional.softmax(
            logits / self.temperature, dim=-1)
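

# Illustrative sketch (an assumption, not from the original training code):
# drawing straight-through one-hot samples over a (batch, length, vocab)
# tensor of logits. Shapes and temperature are placeholders.
def _example_gumbel_one_hot(batch_size=2, seq_len=5, vocab_size=8):
    sampler = GumbelSampler(shape=(batch_size, seq_len, vocab_size),
                            temperature=1.0)
    logits = torch.randn(batch_size, seq_len, vocab_size, requires_grad=True)
    one_hot = sampler.sample(logits)
    # Forward pass: exact one-hot samples; backward pass: softmax gradients.
    return one_hot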


class BinarySampler(GumbelSampler):
    """Straight-through Bernoulli sampling from probabilities."""

    def sample(self, probs):
        # TODO(subhamsahoo): use the temperature parameter.
        pos_noise = self._sampling_noise().to(
            dtype=probs.dtype, device=probs.device)
        neg_noise = self._sampling_noise().to(
            dtype=probs.dtype, device=probs.device)
        del_noise_exp = (neg_noise - pos_noise).exp()
        hard_sample = (probs * (1 + del_noise_exp) > 1).to(probs.dtype)
        soft_sample = probs / (probs + (1 - probs) * del_noise_exp)
        return soft_sample + (hard_sample - soft_sample).detach()
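

# Illustrative sketch (an assumption, not from the original training code):
# straight-through Bernoulli samples from a tensor of probabilities in
# [0, 1]. The shape is a placeholder.
def _example_binary_sample(batch_size=4, num_vars=6):
    sampler = BinarySampler(shape=(batch_size, num_vars))
    probs = torch.sigmoid(
        torch.randn(batch_size, num_vars, requires_grad=True))
    bits = sampler.sample(probs)
    # Forward pass: hard {0, 1} values; backward pass: gradients through
    # the relaxed (soft) sample.
    return bits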


class GaussianSampler:
    """Reparameterized sampling from a diagonal Gaussian."""

    def __init__(self):
        self.softplus = torch.nn.Softplus()

    def sample(self, x):
        # `x` concatenates the mean and the pre-softplus variance along
        # the last dimension.
        assert x.ndim == 2
        n = x.shape[-1] // 2
        mu = x[:, :n]
        sigma = self.softplus(x[:, n:]).sqrt()
        return mu + sigma * torch.randn_like(mu)
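

# Illustrative sketch (an assumption, not from the original training code):
# the reparameterization trick with GaussianSampler, where a network output
# of size 2 * latent_dim holds the mean and pre-softplus variance.
def _example_gaussian_sample(batch_size=4, latent_dim=16):
    sampler = GaussianSampler()
    x = torch.randn(batch_size, 2 * latent_dim, requires_grad=True)
    z = sampler.sample(x)
    # `z` is differentiable w.r.t. `x` because the randomness enters only
    # through `torch.randn_like`.
    return z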