# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch.nn as nn
from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper
from mmengine.runner import Runner

from mmdet.registry import HOOKS


@HOOKS.register_module()
class MeanTeacherHook(Hook):
    """Mean Teacher Hook.

    Mean Teacher is an efficient semi-supervised learning method in
    `Mean Teacher <https://arxiv.org/abs/1703.01780>`_.
    This method requires two models with exactly the same structure,
    as the student model and the teacher model, respectively.
    The student model updates the parameters through gradient descent,
    and the teacher model updates the parameters through
    exponential moving average of the student model.
    Compared with the student model, the teacher model
    is smoother and accumulates more knowledge.

    Args:
        momentum (float): The momentum used for updating teacher's parameter.
            Teacher's parameter are updated with the formula:
            `teacher = (1-momentum) * teacher + momentum * student`.
            Must lie strictly between 0 and 1. Defaults to 0.001.
        interval (int): Update teacher's parameter every interval iteration.
            Must be positive. Defaults to 1.
        skip_buffer (bool): Whether to skip the model buffers, such as
            batchnorm running stats (running_mean, running_var). When True,
            only the parameters take part in the ema operation; when False,
            every floating-point entry of the state dict does.
            Defaults to True.
    """

    def __init__(self,
                 momentum: float = 0.001,
                 interval: int = 1,
                 skip_buffer: bool = True) -> None:
        assert 0 < momentum < 1
        # A zero interval would otherwise only surface later as a
        # ZeroDivisionError inside ``after_train_iter``.
        assert interval > 0, 'interval must be a positive integer'
        self.momentum = momentum
        self.interval = interval
        # Stored in plural form; the constructor keyword stays
        # ``skip_buffer`` so existing configs keep working.
        self.skip_buffers = skip_buffer

    def before_train(self, runner: Runner) -> None:
        """Check that teacher and student exist and synchronise them.

        Args:
            runner (Runner): The runner whose ``model`` (unwrapped if it is
                a model wrapper) must expose ``teacher`` and ``student``
                attributes.
        """
        model = runner.model
        if is_model_wrapper(model):
            model = model.module
        assert hasattr(model, 'teacher')
        assert hasattr(model, 'student')
        # only do it at initial stage
        if runner.iter == 0:
            # momentum=1 turns the EMA formula into a plain copy of the
            # student's values into the teacher.
            self.momentum_update(model, 1)

    def after_train_iter(self,
                         runner: Runner,
                         batch_idx: int,
                         data_batch: Optional[dict] = None,
                         outputs: Optional[dict] = None) -> None:
        """Update teacher's parameter every self.interval iterations.

        Args:
            runner (Runner): The runner providing ``iter`` and ``model``.
            batch_idx (int): Unused; part of the hook signature.
            data_batch (dict, optional): Unused; part of the hook signature.
            outputs (dict, optional): Unused; part of the hook signature.
        """
        if (runner.iter + 1) % self.interval != 0:
            return
        model = runner.model
        if is_model_wrapper(model):
            model = model.module
        self.momentum_update(model, self.momentum)

    def momentum_update(self, model: nn.Module, momentum: float) -> None:
        """Compute the moving average of the parameters using exponential
        moving average.

        The teacher is updated in place as
        ``teacher = (1 - momentum) * teacher + momentum * student``.

        Args:
            model (nn.Module): Module exposing ``student`` and ``teacher``
                sub-modules. Their entries are paired positionally.
            momentum (float): Weight given to the student's values.
        """
        if self.skip_buffers:
            # Parameters only: buffers of the teacher are left untouched.
            for src_parm, dst_parm in zip(model.student.parameters(),
                                          model.teacher.parameters()):
                dst_parm.data.mul_(1 - momentum).add_(
                    src_parm.data, alpha=momentum)
        else:
            for src_parm, dst_parm in zip(
                    model.student.state_dict().values(),
                    model.teacher.state_dict().values()):
                # exclude non-floating-point entries such as batchnorm's
                # num_batches_tracked counter
                if dst_parm.dtype.is_floating_point:
                    dst_parm.data.mul_(1 - momentum).add_(
                        src_parm.data, alpha=momentum)