# Copyright (c) OpenMMLab. All rights reserved.
from ..dist_utils import allreduce_params
from .hook import HOOKS, Hook


@HOOKS.register_module()
class SyncBuffersHook(Hook):
    """Synchronize model buffers such as running_mean and running_var in BN at
    the end of each epoch.

    Args:
        distributed (bool): Whether distributed training is used. It is
            effective only for distributed training. Defaults to True.
    """

    def __init__(self, distributed=True):
        self.distributed = distributed

    def after_epoch(self, runner):
        """All-reduce model buffers at the end of each epoch."""
        if self.distributed:
            allreduce_params(runner.model.buffers())
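
Usage sketch, not part of the file above: in mmcv-based training this hook is typically enabled either through a config entry resolved by the HOOKS registry or by registering it on a runner directly. The lines below are illustrative only; the `runner` object is assumed to be an already-constructed mmcv runner (e.g. an EpochBasedRunner).

# Option 1: enable via a training config; the HOOKS registry resolves the
# string 'SyncBuffersHook' to the class defined above.
custom_hooks = [dict(type='SyncBuffersHook', distributed=True)]

# Option 2: register directly on an existing runner instance
# (`runner` is assumed to exist already in this sketch).
from mmcv.runner import SyncBuffersHook
runner.register_hook(SyncBuffersHook(distributed=True))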