import torch
import torch.nn as nn
import torch.nn.functional as F


class FrozenBatchNorm2d(nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters
    are fixed.
    """

    def __init__(self, n, epsilon=1e-5):
        super(FrozenBatchNorm2d, self).__init__()
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        # Initialized to 1 - epsilon so that (running_var + epsilon) == 1 and
        # the layer is an identity transform until real statistics are loaded.
        self.register_buffer("running_var", torch.ones(n) - epsilon)
        self.epsilon = epsilon

    def forward(self, x):
        """
        Refer to Detectron2
        (https://github.com/facebookresearch/detectron2/blob/cbbc1ce26473cb2a5cc8f58e8ada9ae14cb41052/detectron2/layers/batch_norm.py)
        """
        if x.requires_grad:
            # When gradients are needed, avoid F.batch_norm: its backward op
            # also computes gradients for weight and bias, which wastes memory
            # since they are frozen here. Instead, fold the frozen statistics
            # into a per-channel scale and bias and apply them directly.
            scale = self.weight * (self.running_var + self.epsilon).rsqrt()
            bias = self.bias - self.running_mean * scale
            scale = scale.reshape(1, -1, 1, 1)
            bias = bias.reshape(1, -1, 1, 1)
            out_dtype = x.dtype  # may be a reduced precision dtype under autocast
            return x * scale.to(out_dtype) + bias.to(out_dtype)
        else:
            # When gradients are not needed, F.batch_norm in eval mode is a
            # single fused op and is faster.
            return F.batch_norm(
                x,
                self.running_mean,
                self.running_var,
                self.weight,
                self.bias,
                training=False,
                eps=self.epsilon,
            )
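

# Minimal usage sketch (not part of the original file; names below are
# illustrative). A frozen BN layer is typically swapped in for nn.BatchNorm2d
# when fine-tuning a pretrained backbone with small batch sizes.
if __name__ == "__main__":
    frozen_bn = FrozenBatchNorm2d(16)
    x = torch.randn(2, 16, 8, 8, requires_grad=True)
    y = frozen_bn(x)
    # With the default buffers, (running_var + epsilon) == 1 and
    # running_mean == 0, so the layer acts as an identity transform.
    assert torch.allclose(y, x, atol=1e-5)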