GiusFra committed on
Commit
049c65f
1 Parent(s): 742c3ad

Upload math_model.py with huggingface_hub

Files changed (1)
math_model.py +5 -1
math_model.py CHANGED
@@ -6,7 +6,7 @@ def quantize(tensor, scale, zero_point, is_asym=False):
         clamp_min, clamp_max = torch.tensor(0.), torch.tensor(255.)
     else:
         clamp_min, clamp_max = torch.tensor(-128.), torch.tensor(127.)
-    quant_tensor = torch.clamp(torch.round(tensor/scale), clamp_min, clamp_max) + zero_point
+    quant_tensor = torch.clamp(torch.round(tensor/scale + zero_point), clamp_min, clamp_max)
     return quant_tensor
 
 def dequantize(tensor, scale, zero_point):
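The change moves the zero point inside round() and clamp(), which is the standard asymmetric-quantization form: previously the zero point was added after clamping, so the stored value could fall outside the integer range. A minimal sketch of the difference (the scale and zero-point values below are made up for illustration):

import torch

scale, zp = torch.tensor(0.1), torch.tensor(200.)
x = torch.tensor([10.0])  # x / scale = 100, well inside the uint8 range

# Old formula: zero point added after clamping, so the result can leave [0, 255].
old = torch.clamp(torch.round(x / scale), 0., 255.) + zp   # tensor([300.])

# New formula: the zero point participates in the clamp, so the value saturates.
new = torch.clamp(torch.round(x / scale + zp), 0., 255.)   # tensor([255.])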
@@ -30,6 +30,8 @@ class QuantLinear(nn.Module):
 
     def forward(self, x):
         scaled_x = x * self.mul_factor
+        # With an integer matmul kernel, if the weight zero point is not zero,
+        # an extra input channel equal to the per-channel zero point of the weights is required
         quant_weight = quantize(self.linear.weight, self.weight_scale, self.weight_zp, is_asym=True)
         quant_input = quantize(scaled_x, self.input_scale, self.input_zp, is_asym=False)
         dequantized_weight = dequantize(quant_weight, self.weight_scale, self.weight_zp)
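The new comment refers to the zero-point cross term of an integer kernel: since (q_w - z_w) @ q_x = q_w @ q_x - z_w * sum(q_x), a kernel that accumulates raw integers needs the per-channel z_w * sum(q_x) correction, which an extra input channel can supply. A self-contained check of that identity, with hypothetical shapes not taken from this file (the scale factors cancel on both sides and are omitted):

import torch

torch.manual_seed(0)
q_w = torch.randint(0, 256, (4, 8)).float()    # asymmetric uint8 weights
z_w = torch.randint(0, 256, (4, 1)).float()    # per-output-channel weight zero point
q_x = torch.randint(-128, 128, (8,)).float()   # symmetric int8 input (zero point 0)

# Reference: dequantize the weights first, then matmul, as forward() does here.
ref = (q_w - z_w) @ q_x

# Integer-kernel view: accumulate raw q_w @ q_x, then subtract the zero-point
# cross term z_w * sum(q_x) -- the quantity the extra input channel supplies.
acc = q_w @ q_x - z_w.squeeze(1) * q_x.sum()
assert torch.allclose(ref, acc)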
@@ -54,6 +56,8 @@ class QuantConv2d(nn.Module):
 
     def forward(self, x):
         scaled_x = x * self.mul_factor
+        # With an integer conv kernel, if the weight zero point is not zero,
+        # an extra input channel equal to the per-channel zero point of the weights is required
         quant_weight = quantize(self.linear.weight, self.weight_scale, self.weight_zp, is_asym=True)
         quant_input = quantize(scaled_x, self.input_scale, self.input_zp, is_asym=False)
         dequantized_weight = dequantize(quant_weight, self.weight_scale, self.weight_zp)
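The same correction applies per receptive field in the conv case: the per-window input sum that the extra channel would provide can be emulated with an all-ones kernel. A sketch under the same hypothetical-shape assumptions:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
q_w = torch.randint(0, 256, (2, 3, 3, 3)).float()    # uint8 conv weights
z_w = torch.randint(0, 256, (2,)).float()            # per-output-channel zero point
q_x = torch.randint(-128, 128, (1, 3, 8, 8)).float() # symmetric int8 input

# Reference: dequantize the weights first, then convolve.
ref = F.conv2d(q_x, q_w - z_w.view(-1, 1, 1, 1))

# Integer-kernel view: convolve with raw q_w, then subtract z_w times the
# per-window input sum, computed here with an all-ones kernel.
win_sum = F.conv2d(q_x, torch.ones(1, 3, 3, 3))
acc = F.conv2d(q_x, q_w) - z_w.view(1, -1, 1, 1) * win_sum
assert torch.allclose(ref, acc)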
 