# Copyright (c) OpenMMLab. All rights reserved. | |
from ..dist_utils import allreduce_params | |
from .hook import HOOKS, Hook | |
@HOOKS.register_module()
class SyncBuffersHook(Hook):
    """Synchronize model buffers such as running_mean and running_var in BN at
    the end of each epoch.

    The class is registered in the ``HOOKS`` registry so registry-based
    builders can look it up by name.

    Args:
        distributed (bool): Whether distributed training is used. It is
            effective only for distributed training. Defaults to True.
    """

    def __init__(self, distributed=True):
        # When False, ``after_epoch`` does nothing, so the hook can stay in a
        # config for non-distributed runs without side effects.
        self.distributed = distributed

    def after_epoch(self, runner):
        """All-reduce model buffers at the end of each epoch.

        Args:
            runner: The runner driving training; its ``model.buffers()`` are
                passed to ``allreduce_params``.
        """
        if self.distributed:
            # Only buffers (not parameters) are synchronized here; parameter
            # synchronization is presumably handled elsewhere (e.g. by the
            # distributed model wrapper) — confirm against the runner setup.
            allreduce_params(runner.model.buffers())