from copy import deepcopy
from collections import OrderedDict

import torch


class ModelEma:
    """Keeps an exponential moving average (EMA) of a model's weights."""

    def __init__(self, model, decay=0.9999, device=""):
        # Frozen copy of the model that will hold the averaged weights.
        self.ema = deepcopy(model)
        self.ema.eval()
        self.decay = decay
        self.device = device  # optionally keep the EMA copy on a separate device
        if device:
            self.ema.to(device=device)
        # True when the copied model is wrapped (e.g. DataParallel/DDP),
        # so its state-dict keys carry a "module." prefix.
        self.ema_is_dp = hasattr(self.ema, "module")
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def load_checkpoint(self, checkpoint):
        if isinstance(checkpoint, str):
            checkpoint = torch.load(checkpoint)
        assert isinstance(checkpoint, dict)
        if "model_ema" in checkpoint:
            # Add or strip the "module." prefix so the checkpoint keys
            # match this model's wrapping.
            new_state_dict = OrderedDict()
            for k, v in checkpoint["model_ema"].items():
                if self.ema_is_dp:
                    name = k if k.startswith("module") else "module." + k
                else:
                    name = k.replace("module.", "") if k.startswith("module") else k
                new_state_dict[name] = v
            self.ema.load_state_dict(new_state_dict)

    def state_dict(self):
        return self.ema.state_dict()

    def update(self, model):
        # True when the live model is wrapped but the EMA copy is not, so the
        # live keys carry a "module." prefix that the EMA keys lack.
        pre_module = hasattr(model, "module") and not self.ema_is_dp
        with torch.no_grad():
            curr_msd = model.state_dict()
            # state_dict() covers buffers (e.g. BatchNorm running stats)
            # as well as parameters, so those are averaged too.
            for k, ema_v in self.ema.state_dict().items():
                k = "module." + k if pre_module else k
                model_v = curr_msd[k].detach()
                if self.device:
                    model_v = model_v.to(device=self.device)
                # ema = decay * ema + (1 - decay) * current
                ema_v.copy_(ema_v * self.decay + (1.0 - self.decay) * model_v)
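

# A minimal usage sketch, not part of the original file: wrap the training
# model once, call update() after each optimizer step, and checkpoint the
# smoothed weights under "model_ema" so load_checkpoint() can find them.
# The tiny Linear model, loop length, and SGD settings are illustrative
# assumptions, not anything the original code prescribes.
if __name__ == "__main__":
    import torch.nn as nn

    net = nn.Linear(4, 2)
    ema = ModelEma(net, decay=0.99)
    opt = torch.optim.SGD(net.parameters(), lr=0.1)

    for _ in range(5):
        x = torch.randn(8, 4)
        loss = net(x).pow(2).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()
        ema.update(net)  # fold the freshly updated weights into the running average

    # Evaluate or serialize with the smoothed weights.
    torch.save({"model_ema": ema.state_dict()}, "ema_checkpoint.pth")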