# Copyright (c) OpenMMLab. All rights reserved.
import os
import platform

import numpy as np
import pytest
import torch
import torch.distributed as dist
import torch.nn as nn

if platform.system() == 'Windows':
    import regex as re
else:
    import re


class TestSyncBN:
    """Check ``mmcv.ops.SyncBatchNorm`` against ``torch.nn.BatchNorm2d``.

    Each test runs ``SyncBatchNorm`` on this rank's 4-sample shard of a
    broadcast input while a plain ``BatchNorm2d`` processes the samples of
    the whole synchronisation group, then compares running statistics,
    outputs and gradients.

    The tests only execute inside a Slurm job with exactly 4 tasks, e.g.
    ``srun -p test --gres=gpu:4 -n4``; otherwise they are skipped.
    """

    # Number of Slurm tasks (= ranks) the scenarios below are written for.
    _SLURM_TASKS = 4
    # Number of consecutive samples of the shared input owned by each rank.
    _SAMPLES_PER_RANK = 4

    def dist_init(self):
        """Initialise the default NCCL process group from Slurm variables.

        Reads ``SLURM_PROCID``, ``SLURM_NTASKS``, ``SLURM_LOCALID`` and
        ``SLURM_NODELIST``, exports ``MASTER_ADDR``, ``MASTER_PORT``,
        ``WORLD_SIZE`` and ``RANK``, creates the process group and selects
        the CUDA device matching the local rank.

        Raises:
            KeyError: If one of the Slurm variables is not set.
            IndexError: If ``SLURM_NODELIST`` contains fewer than five
                groups of digits.
        """
        rank = int(os.environ['SLURM_PROCID'])
        world_size = int(os.environ['SLURM_NTASKS'])
        local_rank = int(os.environ['SLURM_LOCALID'])
        node_list = str(os.environ['SLURM_NODELIST'])

        # The 2nd..5th digit groups of the node list are used as the four
        # octets of the master address.
        # NOTE(review): presumably the cluster's node names embed the IP
        # after one leading numeric token - confirm for other clusters.
        node_parts = re.findall('[0-9]+', node_list)
        os.environ['MASTER_ADDR'] = (f'{node_parts[1]}.{node_parts[2]}' +
                                     f'.{node_parts[3]}.{node_parts[4]}')
        os.environ['MASTER_PORT'] = '12341'
        os.environ['WORLD_SIZE'] = str(world_size)
        os.environ['RANK'] = str(rank)

        dist.init_process_group('nccl')
        torch.cuda.set_device(local_rank)

    @classmethod
    def _skip_unless_slurm_job(cls):
        """Skip the calling test unless ``SLURM_NTASKS`` equals 4.

        Skipping (rather than returning) keeps a run outside Slurm from
        being reported as a pass although nothing was checked.
        """
        ntasks = os.environ.get('SLURM_NTASKS')
        if ntasks is None or int(ntasks) != cls._SLURM_TASKS:
            pytest.skip('must run with slurm has 4 processes!\n'
                        'srun -p test --gres=gpu:4 -n4')

    @classmethod
    def _new_group(cls, size, rank):
        """Return the process group of ``size`` ranks that contains ``rank``.

        ``size == 4`` maps to the default world group. For smaller sizes the
        ranks are partitioned into consecutive subgroups.
        """
        if size == cls._SLURM_TASKS:
            return dist.group.WORLD
        # Every rank creates every subgroup, in the same order, including
        # the ones it is not a member of.
        groups = [
            dist.new_group(list(range(first, first + size)))
            for first in range(0, cls._SLURM_TASKS, size)
        ]
        return groups[rank // size]

    @staticmethod
    def _set_weights(layer):
        """Give the three channels of ``layer`` fixed, distinct weights."""
        for channel, value in enumerate((0.2, 0.5, 0.7)):
            layer.weight.data[channel] = value

    @staticmethod
    def _assert_close(name, expected, actual, rtol):
        """Assert that two tensors agree within relative tolerance ``rtol``.

        Args:
            name (str): Quantity being compared, used in the failure message.
            expected (torch.Tensor): Value from the ``BatchNorm2d`` reference.
            actual (torch.Tensor): Value from ``SyncBatchNorm``.
            rtol (float): Relative tolerance forwarded to ``np.allclose``.
        """
        assert np.allclose(
            expected.data.cpu().numpy(),
            actual.data.cpu().numpy(),
            rtol=rtol), f'{name}: SyncBatchNorm differs from BatchNorm2d'

    def _compare_with_batchnorm(self, size, half, batch_size,
                                **syncbn_kwargs):
        """Run one SyncBatchNorm-vs-BatchNorm2d training step and compare.

        Args:
            size (int): Number of ranks per synchronisation group; one of
                1, 2 or 4.
            half (bool): Whether to convert the inputs to half precision.
            batch_size (int): First dimension of the shared random input.
            **syncbn_kwargs: Extra keyword arguments for ``SyncBatchNorm``.

        Returns:
            The process group that was passed to ``SyncBatchNorm``.
        """
        self._skip_unless_slurm_job()
        print('Running syncbn test')
        # Imported lazily so that the module can be collected (and the
        # tests skipped) before the compiled op is needed.
        from mmcv.ops import SyncBatchNorm

        assert size in (1, 2, 4)
        if not dist.is_initialized():
            self.dist_init()
        rank = dist.get_rank()

        torch.manual_seed(9)
        torch.cuda.manual_seed(9)

        self.x = torch.rand(batch_size, 3, 2, 3).cuda()
        self.y_bp = torch.rand(batch_size, 3, 2, 3).cuda()

        if half:
            self.x = self.x.half()
            self.y_bp = self.y_bp.half()
        # All ranks work on rank 0's data.
        dist.broadcast(self.x, src=0)
        dist.broadcast(self.y_bp, src=0)

        torch.cuda.synchronize()
        group = self._new_group(size, rank)

        syncbn = SyncBatchNorm(3, group=group, **syncbn_kwargs).cuda()
        self._set_weights(syncbn)
        syncbn.train()

        bn = nn.BatchNorm2d(3).cuda()
        self._set_weights(bn)
        bn.train()

        # SyncBatchNorm only sees this rank's shard of the input.
        per_rank = self._SAMPLES_PER_RANK
        shard = slice(rank * per_rank, (rank + 1) * per_rank)
        sx = self.x[shard]
        sx.requires_grad_()
        sy = syncbn(sx)
        sy.backward(self.y_bp[shard])

        # The reference BatchNorm2d sees the samples of every rank in this
        # rank's group at once.
        group_span = size * per_rank
        group_start = rank // size * group_span
        x = self.x[group_start:group_start + group_span]
        y_bp = self.y_bp[group_start:group_start + group_span]
        x.requires_grad_()
        y = bn(x)
        y.backward(y_bp)

        # Cut this rank's part out of the group-wide reference results.
        offset = rank % size * per_rank
        within_group = slice(offset, offset + per_rank)
        y = y[within_group]
        x_grad = x.grad[within_group]
        # NOTE(review): the reference parameter gradients are divided by the
        # group size - presumably SyncBatchNorm averages them over the
        # group; confirm against the op's backward.
        w_grad = bn.weight.grad / size
        b_grad = bn.bias.grad / size

        self._assert_close('running_mean', bn.running_mean,
                           syncbn.running_mean, 1e-3)
        self._assert_close('running_var', bn.running_var,
                           syncbn.running_var, 1e-3)
        self._assert_close('output', y, sy, 1e-3)
        self._assert_close('weight grad', w_grad, syncbn.weight.grad, 1e-3)
        self._assert_close('bias grad', b_grad, syncbn.bias.grad, 1e-3)
        self._assert_close('input grad', x_grad, sx.grad, 1e-2)
        return group

    def _test_syncbn_train(self, size=1, half=False):
        """Compare both layers on a 16-sample batch (4 samples per rank).

        Args:
            size (int): Number of ranks per synchronisation group.
            half (bool): Whether to run with half-precision inputs.
        """
        self._compare_with_batchnorm(size, half, batch_size=16)

    def _test_syncbn_empty_train(self, size=1, half=False):
        """Compare both layers on an empty batch using ``stats_mode='N'``.

        Also checks that an unknown ``stats_mode`` is rejected.

        Args:
            size (int): Number of ranks per synchronisation group.
            half (bool): Whether to run with half-precision inputs.
        """
        group = self._compare_with_batchnorm(
            size, half, batch_size=0, stats_mode='N')

        from mmcv.ops import SyncBatchNorm

        # 'stats_mode' only allows 'default' and 'N'
        with pytest.raises(AssertionError):
            SyncBatchNorm(3, group=group, stats_mode='X')

    def test_syncbn_1(self):
        self._test_syncbn_train(size=1)

    def test_syncbn_2(self):
        self._test_syncbn_train(size=2)

    def test_syncbn_4(self):
        self._test_syncbn_train(size=4)

    def test_syncbn_1_half(self):
        self._test_syncbn_train(size=1, half=True)

    def test_syncbn_2_half(self):
        self._test_syncbn_train(size=2, half=True)

    def test_syncbn_4_half(self):
        self._test_syncbn_train(size=4, half=True)

    def test_syncbn_empty_1(self):
        self._test_syncbn_empty_train(size=1)

    def test_syncbn_empty_2(self):
        self._test_syncbn_empty_train(size=2)

    def test_syncbn_empty_4(self):
        self._test_syncbn_empty_train(size=4)

    def test_syncbn_empty_1_half(self):
        self._test_syncbn_empty_train(size=1, half=True)

    def test_syncbn_empty_2_half(self):
        self._test_syncbn_empty_train(size=2, half=True)

    def test_syncbn_empty_4_half(self):
        self._test_syncbn_empty_train(size=4, half=True)