# AiOS / mmcv / tests / test_ops / test_syncbn.py
# Last commit: d7e58f0 "update" by ttxskk (file size 9.69 kB).
# Copyright (c) OpenMMLab. All rights reserved.
import os
import platform
import numpy as np
import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
if platform.system() == 'Windows':
import regex as re
else:
import re
class TestSyncBN:
    """Check ``mmcv.ops.SyncBatchNorm`` against a single-device BatchNorm2d.

    The tests need a SLURM allocation of exactly 4 processes (one CUDA
    device each), e.g. ``srun -p test --gres=gpu:4 -n4``. Rank ``r`` feeds
    rows ``4r:4r+4`` of a shared batch to ``SyncBatchNorm``; the reference
    ``nn.BatchNorm2d`` is fed every row of the rank's sync group. Running
    stats, outputs and gradients of the two are then compared.

    NOTE(review): outside such an allocation each test prints a notice and
    returns early, so pytest reports it as *passed* rather than skipped.
    Consider ``pytest.skip`` if a visible skip is preferred.
    """

    def dist_init(self):
        """Initialise an NCCL process group from SLURM environment variables.

        Reads ``SLURM_PROCID``, ``SLURM_NTASKS``, ``SLURM_LOCALID`` and
        ``SLURM_NODELIST``. Exports ``MASTER_ADDR``, ``MASTER_PORT``,
        ``WORLD_SIZE`` and ``RANK``, then selects CUDA device
        ``SLURM_LOCALID``. Raises ``KeyError`` if a SLURM variable is missing.
        """
        rank = int(os.environ['SLURM_PROCID'])
        world_size = int(os.environ['SLURM_NTASKS'])
        local_rank = int(os.environ['SLURM_LOCALID'])
        node_list = str(os.environ['SLURM_NODELIST'])
        os.environ['MASTER_ADDR'] = self._master_addr_from_nodelist(node_list)
        os.environ['MASTER_PORT'] = '12341'
        os.environ['WORLD_SIZE'] = str(world_size)
        os.environ['RANK'] = str(rank)
        dist.init_process_group('nccl')
        torch.cuda.set_device(local_rank)

    @staticmethod
    def _master_addr_from_nodelist(node_list):
        """Derive a dotted ``MASTER_ADDR`` from a SLURM node-list string.

        The 2nd to 5th runs of digits are joined with dots, e.g.
        ``'SH-IDC1-10-140-0-[1-4]'`` -> ``'10.140.0.1'``.

        NOTE(review): presumably the cluster's host names embed the node IP.
        Confirm before reusing on another cluster.

        Raises:
            IndexError: if ``node_list`` contains fewer than 5 digit runs.
        """
        node_parts = re.findall(r'[0-9]+', node_list)
        return '.'.join(node_parts[i] for i in range(1, 5))

    @staticmethod
    def _slurm_ready():
        """Return True iff SLURM reports exactly 4 tasks; print a notice."""
        if ('SLURM_NTASKS' not in os.environ
                or int(os.environ['SLURM_NTASKS']) != 4):
            print('must run with slurm has 4 processes!\n'
                  'srun -p test --gres=gpu:4 -n4')
            return False
        print('Running syncbn test')
        return True

    @staticmethod
    def _make_group(size, rank):
        """Return the group of ``size`` consecutive ranks containing ``rank``.

        ``size == 4`` uses the default (world) group. For smaller sizes,
        every rank creates *all* groups in the same order, because
        ``dist.new_group`` must be entered by every process of the job.
        """
        if size == 4:
            return dist.group.WORLD
        groups = [
            dist.new_group(list(range(first, first + size)))
            for first in range(0, 4, size)
        ]
        return groups[rank // size]

    @staticmethod
    def _reference_bounds(size, rank):
        """Return ``(start, stop)`` rows covered by ``rank``'s sync group.

        Each of the ``size`` ranks in the group owns 4 consecutive rows.
        """
        rows = 4 * size
        start = rank // size * rows
        return start, start + rows

    @staticmethod
    def _local_offset(size, rank):
        """Return the offset of ``rank``'s own 4 rows within its group."""
        return rank % size * 4

    @staticmethod
    def _set_weights_and_train(module):
        """Set fixed per-channel weights on ``module`` and enable train()."""
        for channel, value in enumerate((0.2, 0.5, 0.7)):
            module.weight.data[channel] = value
        module.train()

    @staticmethod
    def _assert_close(expected, actual, rtol=1e-3):
        """Assert two tensors are close via ``np.allclose(..., rtol)``."""
        assert np.allclose(expected.data.cpu().numpy(),
                           actual.data.cpu().numpy(), rtol)

    def _compare_with_bn(self, syncbn_cls, size, half, batch,
                         **syncbn_kwargs):
        """Run SyncBN on this rank's shard and compare it with BatchNorm2d.

        Args:
            syncbn_cls: The ``SyncBatchNorm`` class under test.
            size (int): Sync-group size, one of 1, 2 or 4.
            half (bool): Whether to convert input and upstream grad to fp16.
            batch (int): Number of rows in the shared input batch.
            **syncbn_kwargs: Extra keyword arguments for ``syncbn_cls``.

        Returns:
            The process group that was handed to ``syncbn_cls``.
        """
        assert size in (1, 2, 4)
        if not dist.is_initialized():
            self.dist_init()
        rank = dist.get_rank()

        torch.manual_seed(9)
        torch.cuda.manual_seed(9)
        self.x = torch.rand(batch, 3, 2, 3).cuda()
        self.y_bp = torch.rand(batch, 3, 2, 3).cuda()
        if half:
            self.x = self.x.half()
            self.y_bp = self.y_bp.half()
        # Every rank adopts rank 0's tensors so all ranks share one batch.
        dist.broadcast(self.x, src=0)
        dist.broadcast(self.y_bp, src=0)
        torch.cuda.synchronize()

        group = self._make_group(size, rank)
        syncbn = syncbn_cls(3, group=group, **syncbn_kwargs).cuda()
        self._set_weights_and_train(syncbn)
        bn = nn.BatchNorm2d(3).cuda()
        self._set_weights_and_train(bn)

        # SyncBN sees only this rank's 4 rows.
        lo, hi = rank * 4, rank * 4 + 4
        sx = self.x[lo:hi]
        sx.requires_grad_()
        sy = syncbn(sx)
        sy.backward(self.y_bp[lo:hi])

        # The reference BN sees every row of the sync group at once ...
        start, stop = self._reference_bounds(size, rank)
        x = self.x[start:stop]
        x.requires_grad_()
        y = bn(x)
        y.backward(self.y_bp[start:stop])
        # ... so keep only the rows that belong to this rank.
        off = self._local_offset(size, rank)
        y = y[off:off + 4]
        x_grad = x.grad[off:off + 4]
        # The test expects SyncBN's parameter grads to equal the reference
        # grads divided by the group size.
        w_grad = bn.weight.grad / size
        b_grad = bn.bias.grad / size

        self._assert_close(bn.running_mean, syncbn.running_mean)
        self._assert_close(bn.running_var, syncbn.running_var)
        self._assert_close(y, sy)
        self._assert_close(w_grad, syncbn.weight.grad)
        self._assert_close(b_grad, syncbn.bias.grad)
        self._assert_close(x_grad, sx.grad, rtol=1e-2)
        return group

    def _test_syncbn_train(self, size=1, half=False):
        """Compare default-mode SyncBN with BatchNorm2d on a 16-row batch."""
        if not self._slurm_ready():
            return
        from mmcv.ops import SyncBatchNorm
        self._compare_with_bn(SyncBatchNorm, size, half, batch=16)

    def _test_syncbn_empty_train(self, size=1, half=False):
        """Compare ``stats_mode='N'`` SyncBN with BatchNorm2d on empty input.

        Afterwards, check that an unknown ``stats_mode`` is rejected.
        """
        if not self._slurm_ready():
            return
        from mmcv.ops import SyncBatchNorm
        group = self._compare_with_bn(
            SyncBatchNorm, size, half, batch=0, stats_mode='N')
        # 'stats_mode' only allows 'default' and 'N'
        with pytest.raises(AssertionError):
            SyncBatchNorm(3, group=group, stats_mode='X')

    def test_syncbn_1(self):
        self._test_syncbn_train(size=1)

    def test_syncbn_2(self):
        self._test_syncbn_train(size=2)

    def test_syncbn_4(self):
        self._test_syncbn_train(size=4)

    def test_syncbn_1_half(self):
        self._test_syncbn_train(size=1, half=True)

    def test_syncbn_2_half(self):
        self._test_syncbn_train(size=2, half=True)

    def test_syncbn_4_half(self):
        self._test_syncbn_train(size=4, half=True)

    def test_syncbn_empty_1(self):
        self._test_syncbn_empty_train(size=1)

    def test_syncbn_empty_2(self):
        self._test_syncbn_empty_train(size=2)

    def test_syncbn_empty_4(self):
        self._test_syncbn_empty_train(size=4)

    def test_syncbn_empty_1_half(self):
        self._test_syncbn_empty_train(size=1, half=True)

    def test_syncbn_empty_2_half(self):
        self._test_syncbn_empty_train(size=2, half=True)

    def test_syncbn_empty_4_half(self):
        self._test_syncbn_empty_train(size=4, half=True)