import torch
import torch.nn as nn
from torch.quantization import QuantStub, DeQuantStub


class RepVGGWholeQuant(nn.Module):
    """Wraps a RepVGG model with quant/dequant stubs for quantization-aware training."""

    def __init__(self, repvgg_model, quantlayers):
        super(RepVGGWholeQuant, self).__init__()
        # quantlayers selects which parts of the network are quantized:
        #   'all'                     - quantize every stage plus gap and linear
        #   'exclud_first_and_linear' - keep stage0, gap and linear in float
        #   'exclud_first_and_last'   - keep stage0, stage4, gap and linear in float
        assert quantlayers in ['all', 'exclud_first_and_linear', 'exclud_first_and_last']
        self.quantlayers = quantlayers
        self.quant = QuantStub()
        self.stage0 = repvgg_model.stage0
        self.stage1 = repvgg_model.stage1
        self.stage2 = repvgg_model.stage2
        self.stage3 = repvgg_model.stage3
        self.stage4 = repvgg_model.stage4
        self.gap = repvgg_model.gap
        self.linear = repvgg_model.linear
        self.dequant = DeQuantStub()

    def forward(self, x):
        if self.quantlayers == 'all':
            # Quantize the input so stage0 also runs quantized.
            x = self.quant(x)
            out = self.stage0(x)
        else:
            # Keep stage0 in float and quantize its output instead.
            out = self.stage0(x)
            out = self.quant(out)
        out = self.stage1(out)
        out = self.stage2(out)
        out = self.stage3(out)
        if self.quantlayers == 'all':
            out = self.stage4(out)
            out = self.gap(out).view(out.size(0), -1)
            out = self.linear(out)
            out = self.dequant(out)
        elif self.quantlayers == 'exclud_first_and_linear':
            # Dequantize after the last stage so the classifier head runs in float.
            out = self.stage4(out)
            out = self.dequant(out)
            out = self.gap(out).view(out.size(0), -1)
            out = self.linear(out)
        else:
            # 'exclud_first_and_last': dequantize before stage4 as well.
            out = self.dequant(out)
            out = self.stage4(out)
            out = self.gap(out).view(out.size(0), -1)
            out = self.linear(out)
        return out

    def fuse_model(self):
        # Fuse conv+bn(+relu) sequences so they can be quantized as single modules.
        for m in self.modules():
            if type(m) == nn.Sequential and hasattr(m, 'conv'):
                if hasattr(m, 'relu'):
                    torch.quantization.fuse_modules(m, ['conv', 'bn', 'relu'], inplace=True)
                else:
                    # Some conv-bn sequences have no ReLU child; fuse only conv+bn.
                    torch.quantization.fuse_modules(m, ['conv', 'bn'], inplace=True)

    def _get_qconfig(self):
        # Default QAT config for the fbgemm backend (x86 server inference).
        return torch.quantization.get_default_qat_qconfig('fbgemm')

    def prepare_quant(self):
        # Fuse modules first, then attach fake-quant modules and observers for QAT.
        self.fuse_model()
        qconfig = self._get_qconfig()
        self.qconfig = qconfig
        torch.quantization.prepare_qat(self, inplace=True)

    def freeze_quant_bn(self):
        # Freeze batch-norm statistics in the fused QAT modules (typically done
        # during the last few epochs of quantization-aware training).
        self.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
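

# A minimal usage sketch (not part of the original module). It assumes a RepVGG
# instance built elsewhere (e.g., by the repo's create_RepVGG_A0 constructor) and
# a user-supplied train_one_epoch function; both names are placeholders here.
#
#     from repvgg import create_RepVGG_A0
#
#     base = create_RepVGG_A0(deploy=False)
#     model = RepVGGWholeQuant(base, quantlayers='all')
#     model.prepare_quant()                            # fuse + insert fake-quant modules
#     for epoch in range(num_epochs):
#         train_one_epoch(model, data_loader)          # ordinary QAT fine-tuning loop
#         if epoch >= num_epochs - 2:
#             model.freeze_quant_bn()                  # freeze BN stats near the end
#     model.eval()
#     quantized = torch.quantization.convert(model)    # produce the quantized model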