import torch
import torch.nn as nn


def quantize(tensor, scale, zero_point, is_asym=False):
    # Asymmetric quantization targets uint8 [0, 255]; symmetric targets int8 [-128, 127]
    if is_asym:
        clamp_min, clamp_max = torch.tensor(0.), torch.tensor(255.)
    else:
        clamp_min, clamp_max = torch.tensor(-128.), torch.tensor(127.)
    quant_tensor = torch.clamp(torch.round(tensor / scale + zero_point), clamp_min, clamp_max)
    return quant_tensor


def dequantize(tensor, scale, zero_point):
    return (tensor - zero_point) * scale
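# A minimal round-trip sketch of the two helpers above (illustrative values only,
# not taken from any real model):
#   x     = torch.tensor([0.1, -0.3, 1.2])
#   scale = torch.tensor(0.01)
#   zp    = torch.tensor(0.)
#   q     = quantize(x, scale, zp)        # -> tensor([ 10., -30., 120.])
#   x_hat = dequantize(q, scale, zp)      # -> tensor([ 0.1000, -0.3000,  1.2000])
# Rounding error in `q` shows up in `x_hat` as an error of at most scale / 2.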


class QuantLinear(nn.Module):
    def __init__(self, in_ch, out_ch, quant_param):
        super().__init__()
        mul_factor = torch.tensor(quant_param['smoothquant_mul']).view(quant_param['smoothquant_mul_shape'])
        self.register_buffer('mul_factor', mul_factor)
        self.linear = nn.Linear(in_ch, out_ch)
        weight_scale = torch.tensor(quant_param['weight_scale']).view(quant_param['weight_scale_shape'])
        weight_zp = torch.tensor(quant_param['weight_zp']).view(quant_param['weight_zp_shape'])
        assert quant_param['weight_zp_dtype'] == 'torch.int8', f"Weight zero-point dtype should be 'torch.int8', found: {quant_param['weight_zp_dtype']}"
        input_scale = torch.tensor(quant_param['input_scale']).view(quant_param['input_scale_shape'])
        input_zp = torch.tensor(quant_param['input_zp']).view(quant_param['input_zp_shape'])
        assert quant_param['input_zp_dtype'] == 'torch.int8', f"Input zero-point dtype should be 'torch.int8', found: {quant_param['input_zp_dtype']}"
        self.register_buffer('weight_scale', weight_scale)
        self.register_buffer('weight_zp', weight_zp)
        self.register_buffer('input_scale', input_scale)
        self.register_buffer('input_zp', input_zp)

    # I.e., "fake quantization": quantize and immediately dequantize, then run the
    # matmul in floating point
    def qdq_forward(self, x):
        scaled_x = x * self.mul_factor
        weight_zp_uint8 = (self.weight_zp + 128).to(torch.uint8).to(torch.float32)
        quant_weight = quantize(self.linear.weight, self.weight_scale, weight_zp_uint8, is_asym=True)
        quant_input = quantize(scaled_x, self.input_scale, self.input_zp, is_asym=False)
        dequantized_weight = dequantize(quant_weight, self.weight_scale, weight_zp_uint8)
        dequantized_input = dequantize(quant_input, self.input_scale, self.input_zp)
        out = torch.nn.functional.linear(dequantized_input, dequantized_weight, self.linear.bias)
        return out

    # Accelerated version: the matmul runs on the quantized (integer-valued) tensors
    def qop_forward(self, x):
        # With an integer linear kernel, if the weight zero-point is not zero,
        # a correction term must be subtracted from the output.
        # The correction term is calculated as follows:
        # - sum the quantized input across the dot-product dimension (e.g., `torch.sum(quant_input, dim=-1)`)
        # - multiply this sum by each weight zero-point (e.g., `torch.sum(quant_input, dim=-1) * self.weight_zp`)
        # - subtract the result from the integer output (e.g., `quant_output -= torch.sum(quant_input, dim=-1) * self.weight_zp`)
        # All other code is just there to make the broadcasting semantics work out.
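        # Concretely, for a single output element with quantized inputs x_q[k] and
        # quantized weights w_q[k] sharing one weight zero-point zp_w:
        #   sum_k x_q[k] * (w_q[k] - zp_w) = sum_k x_q[k] * w_q[k] - zp_w * sum_k x_q[k]
        # so subtracting zp_w * sum_k x_q[k] from the raw integer matmul output
        # recovers the zero-point-corrected dot product.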
        quant_weight = quantize(self.linear.weight, self.weight_scale, self.weight_zp, is_asym=False).to(torch.int8)
        fused_input_scale = self.input_scale / self.mul_factor  # Fuse SmoothQuant and input scales; can be computed offline
        quant_input = quantize(x, fused_input_scale, self.input_zp, is_asym=False).to(torch.int8)
        quant_output = torch.nn.functional.linear(quant_input.to(torch.float32), quant_weight.to(torch.float32), None).to(torch.int32)  # Convert inputs to FP32 to avoid F.linear quantizing the output to int8
        correction = torch.sum(quant_input, dim=-1, keepdim=True).to(torch.int32) * self.weight_zp.to(torch.int8).view([1] * (quant_input.ndim - 1) + [self.weight_zp.nelement()])  # Correct for the weight zero-point
        quant_output = quant_output - correction
        output = dequantize(quant_output, (self.weight_scale * self.input_scale).view([1] * (quant_output.ndim - 1) + [(self.weight_scale * self.input_scale).nelement()]), 0.0)
        output += self.linear.bias
        return output

    def forward(self, x, qop=False):
        if qop:
            return self.qop_forward(x)
        else:
            return self.qdq_forward(x)


class QuantConv2d(nn.Module):
    def __init__(self, in_ch, out_ch, kernel_size, quant_param):
        super().__init__()
        mul_factor = torch.tensor(quant_param['smoothquant_mul']).view(quant_param['smoothquant_mul_shape'])
        self.register_buffer('mul_factor', mul_factor)
        self.conv2d = nn.Conv2d(in_ch, out_ch, kernel_size)
        weight_scale = torch.tensor(quant_param['weight_scale']).view(quant_param['weight_scale_shape'])
        weight_zp = torch.tensor(quant_param['weight_zp']).view(quant_param['weight_zp_shape'])
        assert quant_param['weight_zp_dtype'] == 'torch.int8', f"Weight zero-point dtype should be 'torch.int8', found: {quant_param['weight_zp_dtype']}"
        input_scale = torch.tensor(quant_param['input_scale']).view(quant_param['input_scale_shape'])
        input_zp = torch.tensor(quant_param['input_zp']).view(quant_param['input_zp_shape'])
        assert quant_param['input_zp_dtype'] == 'torch.int8', f"Input zero-point dtype should be 'torch.int8', found: {quant_param['input_zp_dtype']}"
        self.register_buffer('weight_scale', weight_scale)
        self.register_buffer('weight_zp', weight_zp)
        self.register_buffer('input_scale', input_scale)
        self.register_buffer('input_zp', input_zp)

    # I.e., "fake quantization": quantize and immediately dequantize, then run the
    # convolution in floating point
    def qdq_forward(self, x):
        scaled_x = x * self.mul_factor
        weight_zp_uint8 = (self.weight_zp + 128).to(torch.uint8).to(torch.float32)
        quant_weight = quantize(self.conv2d.weight, self.weight_scale, weight_zp_uint8, is_asym=True)
        quant_input = quantize(scaled_x, self.input_scale, self.input_zp, is_asym=False)
        dequantized_weight = dequantize(quant_weight, self.weight_scale, weight_zp_uint8)
        dequantized_input = dequantize(quant_input, self.input_scale, self.input_zp)
        out = torch.nn.functional.conv2d(dequantized_input, dequantized_weight, self.conv2d.bias)
        return out

    # Accelerated version: the convolution runs on the quantized (integer-valued) tensors
    def qop_forward(self, x):
        # With an integer conv2d kernel, if the weight zero-point is not zero,
        # a correction term must be subtracted from the output.
        # Conceptually this is identical to the linear case, except that it's awkward
        # to reduce the input across the dot-product dimensions. That leaves two obvious options:
        # 1. Manually compute the reduction via im2col -> `torch.sum`
        # 2. Add an extra _output channel_ to the convolution whose kernel is all ones (e.g., `torch.ones()`)
        # This example uses option #2.
        # The correction term is then calculated as follows:
        # - append an extra output channel of ones to the weight tensor so the convolution also computes the input sum (e.g., `torch.cat((quant_weight, torch.ones(shape)), dim=0)`)
        # - extract that sum from the output tensor (e.g., `sum = quant_output[:, -1:, :, :]`)
        # - multiply this sum by each weight zero-point (e.g., `sum * self.weight_zp`)
        # - subtract the result from the integer output (e.g., `quant_output -= sum * self.weight_zp`)
        # All other code is just there to make the broadcasting semantics work out.
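        # Concretely: convolving the quantized input with an all-ones kernel produces,
        # at every output location, `sum_k x_q[k]` over exactly the window the real
        # kernels see, so multiplying that channel by the weight zero-point gives the
        # same per-location correction term as in the linear case.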
        quant_weight = quantize(self.conv2d.weight, self.weight_scale, self.weight_zp, is_asym=False).to(torch.int8)
        b_shape = list(quant_weight.shape)  # Used for weight zero-point correction
        b_shape[0] = 1  # Used for weight zero-point correction
        weight_cat = torch.ones((1, 1, 1, 1)).broadcast_to(b_shape).to(torch.int8)  # Used for weight zero-point correction
        quant_weight = torch.cat((quant_weight, weight_cat), dim=0).to(torch.int8)  # Create the extra output channel, used for weight zero-point correction
        fused_input_scale = self.input_scale / self.mul_factor  # Fuse SmoothQuant and input scales; can be computed offline
        quant_input = quantize(x, fused_input_scale, self.input_zp, is_asym=False).to(torch.int8)
        quant_output = torch.nn.functional.conv2d(quant_input.to(torch.float32), quant_weight.to(torch.float32), None).to(torch.int32)  # Convert inputs to FP32 to avoid F.conv2d quantizing the output to int8
        correction = quant_output[:, -1:, :, :] * self.weight_zp.to(torch.int8).view([1, self.weight_zp.nelement()] + [1] * (quant_output.ndim - 2))  # Correct for the weight zero-point; keep the channel dim so it broadcasts over the batch
        quant_output = quant_output[:, :-1, :, :] - correction
        output = dequantize(quant_output, (self.weight_scale * self.input_scale).view([1, (self.weight_scale * self.input_scale).nelement()] + [1] * (quant_output.ndim - 2)), 0.0)
        output += self.conv2d.bias.view([1, self.conv2d.bias.nelement()] + [1] * (quant_output.ndim - 2))
        return output

    def forward(self, x, qop=False):
        if qop:
            return self.qop_forward(x)
        else:
            return self.qdq_forward(x)
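

if __name__ == '__main__':
    # Minimal smoke test for QuantLinear. The quant_param values below are made up
    # purely for illustration (they are not exported from any real model); they only
    # need shapes and dtypes consistent with what __init__ expects. QuantConv2d would
    # take an analogous dictionary, presumably with shapes matched to its 4-D weight.
    torch.manual_seed(0)
    in_ch, out_ch = 4, 3
    quant_param = {
        'smoothquant_mul': [1.0] * in_ch,
        'smoothquant_mul_shape': [1, in_ch],
        'weight_scale': [0.01] * out_ch,
        'weight_scale_shape': [out_ch, 1],
        'weight_zp': [0] * out_ch,
        'weight_zp_shape': [out_ch, 1],
        'weight_zp_dtype': 'torch.int8',
        'input_scale': [0.05],
        'input_scale_shape': [1],
        'input_zp': [0],
        'input_zp_shape': [1],
        'input_zp_dtype': 'torch.int8',
    }
    layer = QuantLinear(in_ch, out_ch, quant_param)
    x = torch.randn(2, in_ch)
    out_qdq = layer(x, qop=False)
    out_qop = layer(x, qop=True)
    # Both paths compute the same quantized matmul (QDQ in float, QOp on integers),
    # so their outputs should agree up to floating-point rounding.
    print('max |qdq - qop| =', (out_qdq - out_qop).abs().max().item())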