# Provenance (residue of the file-viewer page this file was copied from):
# user zdou0830, repository "desco", revision 749745d, size 1.69 kB.
from copy import deepcopy
from collections import OrderedDict
import torch
class ModelEma:
    """Keep an exponential moving average (EMA) copy of a model's state.

    The EMA copy is a ``deepcopy`` of the model given at construction. It is
    put in eval mode with ``requires_grad`` disabled on all its parameters.
    Each call to :meth:`update` moves every floating-point entry of its state
    dict towards the live model::

        ema = decay * ema + (1 - decay) * model

    Non-floating entries (integer/bool buffers such as BatchNorm's
    ``num_batches_tracked``) cannot be meaningfully averaged. They are copied
    from the live model verbatim instead.

    A model exposing its wrapped network as a top-level ``.module`` attribute
    (e.g. DataParallel-style wrappers) is treated as "wrapped". Its state-dict
    keys carry a leading ``"module."``. The methods below add or strip that
    leading prefix so the EMA copy, the live model and a checkpoint need not
    agree on wrapping.

    Args:
        model: the ``torch.nn.Module`` to track.
        decay: weight kept on the previous EMA value at each update.
        device: optional device for the EMA copy. When empty, the copy stays
            where ``deepcopy`` put it, and model tensors are used as-is
            during updates.
    """

    # Leading state-dict key prefix produced by a top-level ``.module`` child.
    _DP_PREFIX = "module."

    def __init__(self, model, decay=0.9999, device=""):
        self.ema = deepcopy(model)
        self.ema.eval()
        self.decay = decay
        self.device = device
        if device:
            self.ema.to(device=device)
        # A top-level ``.module`` attribute is taken to mean "wrapped model".
        self.ema_is_dp = hasattr(self.ema, "module")
        for p in self.ema.parameters():
            p.requires_grad_(False)

    @classmethod
    def _adapt_key(cls, key, want_prefix):
        """Return ``key`` with the leading wrapper prefix added or removed.

        Only a *leading* ``"module."`` is considered. Submodules deeper in the
        hierarchy that happen to be named ``module`` keep their names.
        """
        has_prefix = key.startswith(cls._DP_PREFIX)
        if want_prefix and not has_prefix:
            return cls._DP_PREFIX + key
        if has_prefix and not want_prefix:
            return key[len(cls._DP_PREFIX):]
        return key

    def load_checkpoint(self, checkpoint):
        """Load the ``"model_ema"`` entry of a checkpoint into the EMA copy.

        Args:
            checkpoint: a path passed to ``torch.load``, or an already-loaded
                dict. Keys of ``checkpoint["model_ema"]`` get the leading
                ``"module."`` added or stripped to match the EMA copy's
                wrapping. If the entry is absent, nothing is loaded.

        Raises:
            AssertionError: if ``checkpoint`` is not (or does not load to) a
                dict.
            RuntimeError: propagated from ``load_state_dict`` when keys or
                shapes do not match the EMA copy.
        """
        if isinstance(checkpoint, str):
            # NOTE(review): torch.load unpickles arbitrary objects; only load
            # trusted files. map_location="cpu" lets checkpoints saved on a
            # device that is unavailable here load; load_state_dict then
            # copies the tensors onto the EMA copy's own device.
            checkpoint = torch.load(checkpoint, map_location="cpu")
        assert isinstance(checkpoint, dict)
        if "model_ema" in checkpoint:
            new_state_dict = OrderedDict(
                (self._adapt_key(k, self.ema_is_dp), v)
                for k, v in checkpoint["model_ema"].items()
            )
            self.ema.load_state_dict(new_state_dict)

    def state_dict(self):
        """Return the EMA copy's state dict."""
        return self.ema.state_dict()

    def update(self, model):
        """Blend the current state of ``model`` into the EMA copy in place.

        ``model`` may be wrapped (top-level ``.module``) or not, independently
        of whether the EMA copy is. Keys are translated accordingly.
        """
        model_is_dp = hasattr(model, "module")
        with torch.no_grad():
            curr_msd = model.state_dict()
            # state_dict() tensors share storage with the EMA copy, so copy_
            # below updates the EMA model itself.
            for k, ema_v in self.ema.state_dict().items():
                model_v = curr_msd[self._adapt_key(k, model_is_dp)].detach()
                if self.device:
                    model_v = model_v.to(device=self.device)
                if ema_v.is_floating_point() or ema_v.is_complex():
                    ema_v.copy_(ema_v * self.decay + (1.0 - self.decay) * model_v)
                else:
                    # Averaging an integer/bool buffer yields a float that
                    # copy_ would truncate (e.g. a step counter stuck at 0);
                    # track the live value instead.
                    ema_v.copy_(model_v)