|
import torch.nn as nn |
|
import torch |
|
|
|
def quantize(tensor, scale, zero_point, is_asym=False, bit_width=8):
    """Quantize ``tensor`` onto an integer grid (values stay in a float dtype).

    Computes ``clamp(round(tensor / scale + zero_point), qmin, qmax)``.

    The range is ``[0, 2**bit_width - 1]`` when ``is_asym`` is true (unsigned).
    Otherwise it is ``[-2**(bit_width - 1), 2**(bit_width - 1) - 1]`` (signed).
    With the default ``bit_width=8`` these are [0, 255] and [-128, 127].
    ``torch.round`` rounds exact halves to the nearest even integer.

    Args:
        tensor: Tensor to quantize.
        scale: Quantization step, broadcast against ``tensor``.
        zero_point: Offset added after dividing by ``scale``, broadcast
            against ``tensor``.
        is_asym: Use the unsigned range instead of the signed one.
        bit_width: Number of bits of the integer grid (default 8).

    Returns:
        Tensor of integer-valued elements inside the selected range. No
        cast to an integer dtype is performed; callers do that themselves.
    """
    if is_asym:
        clamp_min, clamp_max = 0.0, 2.0 ** bit_width - 1.0
    else:
        half_range = 2.0 ** (bit_width - 1)
        clamp_min, clamp_max = -half_range, half_range - 1.0
    # Python scalars as bounds: no per-call tensor allocation and no device
    # placement concerns. Type promotion matches 0-dim float tensor bounds.
    return torch.clamp(torch.round(tensor / scale + zero_point), clamp_min, clamp_max)
|
|
|
def dequantize(tensor, scale, zero_point):
    """Map quantized values back to the real domain.

    Computes ``(tensor - zero_point) * scale``. This undoes ``quantize`` up to
    its rounding and clamping error.

    Args:
        tensor: Quantized values (any numeric dtype).
        scale: Quantization step, broadcast against ``tensor``.
        zero_point: Integer offset to remove, broadcast against ``tensor``.

    Returns:
        The dequantized values.
    """
    centered = tensor - zero_point
    return centered * scale
|
|
|
|
|
class QuantLinear(nn.Module):
    """Linear layer with a SmoothQuant input multiplier and 8-bit quantization.

    The constructor builds ``nn.Linear(in_ch, out_ch)`` and registers the
    quantization parameters from ``quant_param`` as buffers. ``forward``
    dispatches to one of two formulations:

    * ``qdq_forward``: fake-quantize (quantize then dequantize) the weight and
      the smoothed input, then run a floating-point ``linear``.
    * ``qop_forward``: emulate an integer kernel. An int8 weight times an int8
      input is followed by integer zero-point corrections and one rescale
      back to float.

    Args:
        in_ch: Number of input features.
        out_ch: Number of output features.
        quant_param: Mapping read for ``smoothquant_mul``, ``weight_scale``,
            ``weight_zp``, ``input_scale`` and ``input_zp``, each with a
            matching ``*_shape`` key. It must also contain
            ``weight_zp_dtype`` and ``input_zp_dtype``, both equal to
            ``'torch.int8'``.

    Raises:
        AssertionError: If a zero-point dtype entry is not ``'torch.int8'``.
    """

    def __init__(self, in_ch, out_ch, quant_param):
        super().__init__()

        mul_factor = torch.tensor(quant_param['smoothquant_mul']).view(quant_param['smoothquant_mul_shape'])
        self.register_buffer('mul_factor', mul_factor)

        self.linear = nn.Linear(in_ch, out_ch)

        # NOTE: these checks are plain asserts and are skipped under `python -O`.
        # Each message reads the same key that is checked.
        weight_scale = torch.tensor(quant_param['weight_scale']).view(quant_param['weight_scale_shape'])
        weight_zp = torch.tensor(quant_param['weight_zp']).view(quant_param['weight_zp_shape'])
        assert quant_param['weight_zp_dtype'] == 'torch.int8', f"Weight Zero-Point dtype should be 'torch.int8', found: {quant_param['weight_zp_dtype']}"

        input_scale = torch.tensor(quant_param['input_scale']).view(quant_param['input_scale_shape'])
        input_zp = torch.tensor(quant_param['input_zp']).view(quant_param['input_zp_shape'])
        assert quant_param['input_zp_dtype'] == 'torch.int8', f"Input Zero-Point dtype should be 'torch.int8', found: {quant_param['input_zp_dtype']}"

        self.register_buffer('weight_scale', weight_scale)
        self.register_buffer('weight_zp', weight_zp)
        self.register_buffer('input_scale', input_scale)
        self.register_buffer('input_zp', input_zp)

    def qdq_forward(self, x):
        """Fake-quantized forward: quantize/dequantize weight and input, then ``linear``.

        The weight uses the unsigned [0, 255] grid with zero-point
        ``weight_zp + 128``. That is the signed [-128, 127] grid with
        zero-point ``weight_zp``, shifted by 128.
        """
        scaled_x = x * self.mul_factor

        weight_zp_uint8 = (self.weight_zp + 128).to(torch.uint8).to(torch.float32)
        quant_weight = quantize(self.linear.weight, self.weight_scale, weight_zp_uint8, is_asym=True)
        quant_input = quantize(scaled_x, self.input_scale, self.input_zp, is_asym=False)

        dequantized_weight = dequantize(quant_weight, self.weight_scale, weight_zp_uint8)
        dequantized_input = dequantize(quant_input, self.input_scale, self.input_zp)

        return torch.nn.functional.linear(dequantized_input, dequantized_weight, self.linear.bias)

    def qop_forward(self, x):
        """Integer-emulated forward pass mirroring ``qdq_forward``.

        Notation: ``q_x`` is the int8 input with zero-point ``z_x``, and
        ``q_w`` is the int8 weight with zero-point ``z_w``. Output feature
        ``j`` accumulates::

            sum_k (q_x[k] - z_x[k]) * (q_w[j, k] - z_w[j])
              = sum_k q_x[k] * q_w[j, k]              # raw int8 matmul
                - z_w[j] * sum_k q_x[k]               # weight zero-point term
                - sum_k z_x[k] * (q_w[j, k] - z_w[j]) # input zero-point term

        The accumulator is rescaled by ``weight_scale * input_scale`` and the
        float bias is added. The SmoothQuant multiplier is folded into the
        input scale instead of multiplying ``x``. Borderline values can
        therefore round differently from ``qdq_forward``.

        Assumes ``input_zp`` holds one value or one value per input feature —
        TODO confirm against the producer of ``quant_param``.

        NOTE: the raw product is accumulated in float32. That is exact only
        while every partial sum stays below 2**24 in magnitude.
        """
        # Signed int8 weight; same grid as the uint8/+128 form in qdq_forward.
        quant_weight = quantize(self.linear.weight, self.weight_scale, self.weight_zp, is_asym=False).to(torch.int8)

        # x / (s / m) == (x * m) / s: fold the SmoothQuant multiplier into the scale.
        fused_input_scale = self.input_scale / self.mul_factor
        quant_input = quantize(x, fused_input_scale, self.input_zp, is_asym=False).to(torch.int8)

        quant_output = torch.nn.functional.linear(
            quant_input.to(torch.float32), quant_weight.to(torch.float32), None
        ).to(torch.int32)

        # Broadcast shape for per-output-feature quantities: (1, ..., 1, out).
        out_shape = [1] * (quant_output.ndim - 1) + [-1]
        weight_zp = self.weight_zp.to(torch.int32)

        # Weight zero-point term: z_w[j] * sum_k q_x[k].
        input_sum = torch.sum(quant_input, dim=-1, keepdim=True).to(torch.int32)
        correction = input_sum * weight_zp.reshape(out_shape)

        # Input zero-point term: sum_k z_x[k] * (q_w[j, k] - z_w[j]).
        # It is identically zero when the input zero-point is 0.
        in_features = quant_weight.shape[-1]
        input_zp = self.input_zp.to(torch.int32).reshape(-1).broadcast_to((in_features,))
        centered_weight = quant_weight.to(torch.int32) - weight_zp.reshape(-1, 1)
        input_zp_term = (centered_weight * input_zp).sum(dim=-1).to(torch.int32)
        correction = correction + input_zp_term.reshape(out_shape)

        quant_output = quant_output - correction

        out_scale = (self.weight_scale * self.input_scale).reshape(out_shape)
        output = dequantize(quant_output, out_scale, 0.0)
        if self.linear.bias is not None:
            output += self.linear.bias
        return output

    def forward(self, x, qop=False):
        """Run ``qop_forward`` when ``qop`` is true, otherwise ``qdq_forward``."""
        return self.qop_forward(x) if qop else self.qdq_forward(x)
|
|
|
class QuantConv2d(nn.Module):
    """Conv2d layer with a SmoothQuant input multiplier and 8-bit quantization.

    The constructor builds ``nn.Conv2d(in_ch, out_ch, kernel_size)``, using
    PyTorch's defaults for everything else (no padding, with bias). It
    registers the quantization parameters from ``quant_param`` as buffers.
    ``forward`` dispatches to one of two formulations:

    * ``qdq_forward``: fake-quantize the weight and the smoothed input, then
      run a floating-point ``conv2d``.
    * ``qop_forward``: emulate an integer kernel. An int8 weight convolved
      with an int8 input is followed by integer zero-point corrections and
      one rescale back to float.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        kernel_size: Kernel size passed to ``nn.Conv2d``.
        quant_param: Mapping read for ``smoothquant_mul``, ``weight_scale``,
            ``weight_zp``, ``input_scale`` and ``input_zp``, each with a
            matching ``*_shape`` key. It must also contain
            ``weight_zp_dtype`` and ``input_zp_dtype``, both equal to
            ``'torch.int8'``.

    Raises:
        AssertionError: If a zero-point dtype entry is not ``'torch.int8'``.
    """

    def __init__(self, in_ch, out_ch, kernel_size, quant_param):
        super().__init__()

        mul_factor = torch.tensor(quant_param['smoothquant_mul']).view(quant_param['smoothquant_mul_shape'])
        self.register_buffer('mul_factor', mul_factor)

        self.conv2d = nn.Conv2d(in_ch, out_ch, kernel_size)

        # NOTE: these checks are plain asserts and are skipped under `python -O`.
        # Each message reads the same key that is checked.
        weight_scale = torch.tensor(quant_param['weight_scale']).view(quant_param['weight_scale_shape'])
        weight_zp = torch.tensor(quant_param['weight_zp']).view(quant_param['weight_zp_shape'])
        assert quant_param['weight_zp_dtype'] == 'torch.int8', f"Weight Zero-Point dtype should be 'torch.int8', found: {quant_param['weight_zp_dtype']}"

        input_scale = torch.tensor(quant_param['input_scale']).view(quant_param['input_scale_shape'])
        input_zp = torch.tensor(quant_param['input_zp']).view(quant_param['input_zp_shape'])
        assert quant_param['input_zp_dtype'] == 'torch.int8', f"Input Zero-Point dtype should be 'torch.int8', found: {quant_param['input_zp_dtype']}"

        self.register_buffer('weight_scale', weight_scale)
        self.register_buffer('weight_zp', weight_zp)
        self.register_buffer('input_scale', input_scale)
        self.register_buffer('input_zp', input_zp)

    def qdq_forward(self, x):
        """Fake-quantized forward: quantize/dequantize weight and input, then ``conv2d``.

        The weight uses the unsigned [0, 255] grid with zero-point
        ``weight_zp + 128``. That is the signed [-128, 127] grid with
        zero-point ``weight_zp``, shifted by 128.
        """
        scaled_x = x * self.mul_factor

        weight_zp_uint8 = (self.weight_zp + 128).to(torch.uint8).to(torch.float32)
        quant_weight = quantize(self.conv2d.weight, self.weight_scale, weight_zp_uint8, is_asym=True)
        quant_input = quantize(scaled_x, self.input_scale, self.input_zp, is_asym=False)

        dequantized_weight = dequantize(quant_weight, self.weight_scale, weight_zp_uint8)
        dequantized_input = dequantize(quant_input, self.input_scale, self.input_zp)

        return torch.nn.functional.conv2d(dequantized_input, dequantized_weight, self.conv2d.bias)

    def qop_forward(self, x):
        """Integer-emulated forward pass mirroring ``qdq_forward``.

        The decomposition matches a quantized matmul, summed over each
        receptive field (input channel, kernel row, kernel column). For
        output channel ``o``::

            sum (q_x - z_x) * (q_w[o] - z_w[o])
              = sum q_x * q_w[o]                 # raw int8 convolution
                - z_w[o] * sum q_x               # weight zero-point term
                - sum z_x * (q_w[o] - z_w[o])    # input zero-point term

        The per-window input sum ``sum q_x`` comes from an all-ones filter
        appended to the weight. It is produced by the same convolution as an
        extra, last output channel. The convolution uses no padding (the
        ``conv2d`` default, as in ``qdq_forward``), so every window covers
        real inputs. The input zero-point term is therefore the same at every
        output position.

        The SmoothQuant multiplier is folded into the input scale, so
        borderline values can round differently from ``qdq_forward``.

        Assumes a 4-D ``(N, C, H, W)`` input and an ``input_zp`` with one value
        or one value per input channel — TODO confirm against the producer
        of ``quant_param``.

        NOTE: the raw product is accumulated in float32. That is exact only
        while every partial sum stays below 2**24 in magnitude.
        """
        # Signed int8 weight; same grid as the uint8/+128 form in qdq_forward.
        quant_weight = quantize(self.conv2d.weight, self.weight_scale, self.weight_zp, is_asym=False).to(torch.int8)

        # Extra all-ones filter: its output channel is sum(q_x) over each window.
        # Allocated on the weight's device so torch.cat works on GPU as well.
        ones_shape = list(quant_weight.shape)
        ones_shape[0] = 1
        ones_filter = torch.ones(ones_shape, dtype=torch.int8, device=quant_weight.device)
        weight_with_ones = torch.cat((quant_weight, ones_filter), dim=0)

        # x / (s / m) == (x * m) / s: fold the SmoothQuant multiplier into the scale.
        fused_input_scale = self.input_scale / self.mul_factor
        quant_input = quantize(x, fused_input_scale, self.input_zp, is_asym=False).to(torch.int8)

        quant_output = torch.nn.functional.conv2d(
            quant_input.to(torch.float32), weight_with_ones.to(torch.float32), None
        ).to(torch.int32)

        # Broadcast shape for per-output-channel quantities: (1, out, 1, 1).
        channel_shape = [1, -1] + [1] * (quant_output.ndim - 2)

        # Slice with `-1:` so the window sums keep their channel axis,
        # (N, 1, H, W), and broadcast over output channels for any batch size.
        input_sum = quant_output[:, -1:, :, :]
        quant_output = quant_output[:, :-1, :, :]

        weight_zp = self.weight_zp.to(torch.int32)

        # Weight zero-point term: z_w[o] * sum(q_x) per window.
        correction = input_sum * weight_zp.reshape(channel_shape)

        # Input zero-point term, per output channel o:
        #   sum over (c, kh, kw) of z_x[c] * (q_w[o, c, kh, kw] - z_w[o]).
        # It is identically zero when the input zero-point is 0.
        in_channels = quant_weight.shape[1]
        input_zp = self.input_zp.to(torch.int32).reshape(-1).broadcast_to((in_channels,))
        centered_weight = quant_weight.to(torch.int32) - weight_zp.reshape(-1, 1, 1, 1)
        input_zp_term = (centered_weight * input_zp.reshape(1, -1, 1, 1)).sum(dim=(1, 2, 3)).to(torch.int32)
        correction = correction + input_zp_term.reshape(channel_shape)

        quant_output = quant_output - correction

        out_scale = (self.weight_scale * self.input_scale).reshape(channel_shape)
        output = dequantize(quant_output, out_scale, 0.0)
        if self.conv2d.bias is not None:
            output += self.conv2d.bias.reshape(channel_shape)
        return output

    def forward(self, x, qop=False):
        """Run ``qop_forward`` when ``qop`` is true, otherwise ``qdq_forward``."""
        return self.qop_forward(x) if qop else self.qdq_forward(x)
|
|