# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Used for EMA tracking a given pytorch module. The user is responsible for
calling step() and setting the appropriate decay.
"""

import copy
import logging

import torch


class EMAModule:
    """Exponential Moving Average of Fairseq Models"""

    def __init__(
        self, model, ema_decay=0.9999, ema_fp32=False, device=None, skip_keys=None
    ):
        """
        @param model model to initialize the EMA with
        @param ema_decay decay factor applied at each EMA update
        @param ema_fp32 if True, keep a separate fp32 copy of the EMA params
        @param device If provided, copy EMA to this device (e.g. gpu).
                      Otherwise EMA stays on the same device as the model.
        @param skip_keys state dict keys that are copied verbatim instead of averaged
        """
        self.decay = ema_decay
        self.ema_fp32 = ema_fp32
        self.model = copy.deepcopy(model)
        self.model.requires_grad_(False)
        self.skip_keys = skip_keys or set()
        self.fp32_params = {}

        if device is not None:
            logging.info(f"Copying EMA model to device {device}")
            self.model = self.model.to(device=device)

        if self.ema_fp32:
            self.build_fp32_params()

        self.update_freq_counter = 0

    def build_fp32_params(self, state_dict=None):
        """
        Store a copy of the EMA params in fp32.
        If a state dict is passed, the EMA params are copied from
        the provided state dict. Otherwise, they are copied from the
        current EMA model parameters.
        """
        if not self.ema_fp32:
            raise RuntimeError(
                "build_fp32_params should not be called if ema_fp32=False. "
                "Use ema_fp32=True if this is really intended."
            )

        if state_dict is None:
            state_dict = self.model.state_dict()

        def _to_float(t):
            return t.float() if torch.is_floating_point(t) else t
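
        # Only floating-point tensors are upcast; integer buffers (e.g. BatchNorm's
        # num_batches_tracked) keep their original dtype. Existing fp32 buffers are
        # reused in place when rebuilding from a state dict.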
        for param_key in state_dict:
            if param_key in self.fp32_params:
                self.fp32_params[param_key].copy_(state_dict[param_key])
            else:
                self.fp32_params[param_key] = _to_float(state_dict[param_key])

    def restore(self, state_dict, build_fp32_params=False):
        """Load data from a model spec into EMA model"""
        self.model.load_state_dict(state_dict, strict=False)
        if build_fp32_params:
            self.build_fp32_params(state_dict)

    def set_decay(self, decay):
        self.decay = decay

    def get_decay(self):
        return self.decay

    def _step_internal(self, new_model):
        """One update of the EMA model based on new model weights"""
        decay = self.decay

        ema_state_dict = {}
        ema_params = self.fp32_params if self.ema_fp32 else self.model.state_dict()
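        # Blend every parameter and buffer of the new model into the EMA copy.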
        for key, param in new_model.state_dict().items():
            if isinstance(param, dict):
                continue
            try:
                ema_param = ema_params[key]
            except KeyError:
                ema_param = (
                    param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
                )

            if param.shape != ema_param.shape:
                raise ValueError(
                    "incompatible tensor shapes between model param and ema param"
                    + " {} vs. {}".format(param.shape, ema_param.shape)
                )

            if "version" in key:
                # Do not decay a model.version pytorch param
                continue

            if key in self.skip_keys or (
                "num_batches_tracked" in key and ema_param.dtype == torch.int64
            ):
                ema_param = param.to(dtype=ema_param.dtype).clone()
                ema_params[key].copy_(ema_param)
            else:
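                # Standard EMA update: ema = decay * ema + (1 - decay) * param.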
                ema_param.mul_(decay)
                ema_param.add_(param.to(dtype=ema_param.dtype), alpha=1 - decay)
            ema_state_dict[key] = ema_param

        self.restore(ema_state_dict, build_fp32_params=False)

    def step(self, new_model):
        self._step_internal(new_model)

    def reverse(self, model):
        """
        Load the model parameters from EMA model.
        Useful for inference or fine-tuning from the EMA model.
        """
        d = self.model.state_dict()
        if "_ema" in d:
            del d["_ema"]
        model.load_state_dict(d, strict=False)
        return model
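

# Example usage: a minimal illustrative sketch. The tiny Linear model, SGD
# optimizer, and random data below are hypothetical stand-ins for a real
# fairseq training setup.
if __name__ == "__main__":
    import torch.nn as nn

    net = nn.Linear(16, 4)
    ema = EMAModule(net, ema_decay=0.999, ema_fp32=True)
    optimizer = torch.optim.SGD(net.parameters(), lr=0.1)

    for _ in range(10):
        loss = net(torch.randn(8, 16)).pow(2).mean()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        # The caller is responsible for invoking step() after each update.
        ema.step(net)

    # Copy the averaged weights back into a fresh model for evaluation.
    eval_net = ema.reverse(copy.deepcopy(net))
    print(eval_net)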