# Source: RepVGG / RepVGG-main / quantization / repvgg_quantized.py
# Hosting-page residue: yuxi-liu-wired, commit "init" (0decf42)
import torch
import torch.nn as nn
from torch.quantization import QuantStub, DeQuantStub
class RepVGGWholeQuant(nn.Module):
    """Wrap a RepVGG model for eager-mode quantization-aware training (QAT).

    The wrapper re-uses ``stage0`` .. ``stage4``, ``gap`` and ``linear`` from
    ``repvgg_model`` and places a ``QuantStub`` / ``DeQuantStub`` pair
    according to ``quantlayers``:

    * ``'all'``: every stage plus ``gap`` and ``linear`` run between the stubs.
    * ``'exclud_first_and_linear'``: ``stage0`` runs before the QuantStub;
      ``gap`` and ``linear`` run after the DeQuantStub (float).
    * ``'exclud_first_and_last'``: only ``stage1`` .. ``stage3`` run between
      the stubs; ``stage0``, ``stage4``, ``gap`` and ``linear`` stay float.

    The mode strings (including the ``exclud`` spelling) are part of the
    public interface and are kept unchanged.
    """

    _QUANTLAYER_MODES = ('all', 'exclud_first_and_linear', 'exclud_first_and_last')

    def __init__(self, repvgg_model, quantlayers):
        """Build the wrapper.

        Args:
            repvgg_model: object exposing ``stage0`` .. ``stage4``, ``gap`` and
                ``linear`` submodules. They are shared with it, not copied.
            quantlayers: one of ``'all'``, ``'exclud_first_and_linear'``,
                ``'exclud_first_and_last'``.

        Raises:
            ValueError: if ``quantlayers`` is not one of the supported modes.
        """
        super().__init__()
        # Raise rather than assert: asserts vanish under ``python -O`` and an
        # unknown mode would then silently take the last branch of forward().
        if quantlayers not in self._QUANTLAYER_MODES:
            raise ValueError(
                f"quantlayers must be one of {self._QUANTLAYER_MODES}, got {quantlayers!r}"
            )
        self.quantlayers = quantlayers
        self.quant = QuantStub()
        self.stage0 = repvgg_model.stage0
        self.stage1 = repvgg_model.stage1
        self.stage2 = repvgg_model.stage2
        self.stage3 = repvgg_model.stage3
        self.stage4 = repvgg_model.stage4
        self.gap = repvgg_model.gap
        self.linear = repvgg_model.linear
        self.dequant = DeQuantStub()

    def forward(self, x):
        """Run the network with Quant/DeQuant stubs placed per ``quantlayers``."""
        if self.quantlayers == 'all':
            x = self.quant(x)
            out = self.stage0(x)
        else:
            # stage0 stays float; quantize its output.
            out = self.stage0(x)
            out = self.quant(out)
        out = self.stage1(out)
        out = self.stage2(out)
        out = self.stage3(out)
        if self.quantlayers == 'all':
            out = self.stage4(out)
            out = self.gap(out).view(out.size(0), -1)
            out = self.linear(out)
            out = self.dequant(out)
        elif self.quantlayers == 'exclud_first_and_linear':
            out = self.stage4(out)
            out = self.dequant(out)
            out = self.gap(out).view(out.size(0), -1)
            out = self.linear(out)
        else:  # 'exclud_first_and_last'
            out = self.dequant(out)
            out = self.stage4(out)
            out = self.gap(out).view(out.size(0), -1)
            out = self.linear(out)
        return out

    # From https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html
    def fuse_model(self):
        """Fuse every ``nn.Sequential`` holding ``conv``/``bn``/``relu`` in place.

        The ReLU was moved from "block.nonlinearity" into "rbr_reparam"
        (an nn.Sequential), which makes these triples fusable with the stock
        APIs. Presumably ``repvgg_model`` is a deploy-mode RepVGG built that
        way. Any matching Sequential must have all three children.
        """
        try:
            # PyTorch >= 1.11: plain ``fuse_modules`` is the eval-only fuser
            # (it folds BN into the conv and asserts the modules are not
            # training). QAT needs ``fuse_modules_qat``, which keeps BN
            # trainable for prepare_qat to swap in.
            from torch.ao.quantization import fuse_modules_qat as fuse
        except ImportError:
            # Older releases: fuse_modules chose the QAT fuser itself when the
            # modules were in training mode.
            fuse = torch.quantization.fuse_modules
        # Snapshot the traversal: fusion replaces children while we walk.
        for m in list(self.modules()):
            # Exact type check: torch's fused modules subclass nn.Sequential
            # and must never be handed back to the fuser.
            if type(m) is nn.Sequential and hasattr(m, 'conv'):
                fuse(m, ['conv', 'bn', 'relu'], inplace=True)

    def _get_qconfig(self):
        """Return the QConfig used for QAT (fbgemm defaults)."""
        return torch.quantization.get_default_qat_qconfig('fbgemm')

    def _float_submodules(self):
        """Return the submodules that forward() runs outside the Quant/DeQuant stubs."""
        if self.quantlayers == 'exclud_first_and_linear':
            return [self.stage0, self.gap, self.linear]
        if self.quantlayers == 'exclud_first_and_last':
            return [self.stage0, self.stage4, self.gap, self.linear]
        return []

    def prepare_quant(self):
        """Fuse, attach a QConfig, and convert the model in place for QAT.

        Submodules that forward() runs on float tensors get ``qconfig = None``.
        prepare_qat then neither fake-quantizes them nor swaps them for QAT
        modules, so a later convert leaves them float.
        """
        # From https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html
        self.fuse_model()
        self.qconfig = self._get_qconfig()
        # Without this, the model-level qconfig propagates into the excluded
        # layers, and they are quantized despite receiving float input.
        for module in self._float_submodules():
            module.qconfig = None
        torch.quantization.prepare_qat(self, inplace=True)

    def freeze_quant_bn(self):
        """Freeze BN statistics in the fused QAT modules.

        NOTE(review): this goes through ``torch.nn.intrinsic.qat.freeze_bn_stats``,
        so it presumably affects only swapped QAT conv-bn modules. BN in layers
        excluded from quantization is not frozen here.
        """
        self.apply(torch.nn.intrinsic.qat.freeze_bn_stats)