File size: 707 Bytes
45b4aa7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# Copyright (c) OpenMMLab. All rights reserved.
from ..dist_utils import allreduce_params
from .hook import HOOKS, Hook


@HOOKS.register_module()
class SyncBuffersHook(Hook):
    """Hook that synchronizes model buffers once every epoch has finished.

    Buffers (for example the ``running_mean`` and ``running_var`` of BN
    layers) are handed to ``allreduce_params`` at the end of each epoch.

    Args:
        distributed (bool): Whether distributed training is used. The hook
            only takes effect when this is truthy. Defaults to True.
    """

    def __init__(self, distributed=True):
        # Stored flag decides whether ``after_epoch`` does any work.
        self.distributed = distributed

    def after_epoch(self, runner):
        """All-reduce the model buffers at the end of an epoch.

        Args:
            runner: Object exposing the model as ``runner.model``; the
                model must provide a ``buffers()`` method.
        """
        # Nothing to synchronize when distributed mode is switched off.
        if not self.distributed:
            return

        model_buffers = runner.model.buffers()
        allreduce_params(model_buffers)