nickfraser committed
Commit: 9ab1060
Parent: 673c9f2

Fix: set `keepdim=True`

Files changed (1)
  1. math_model.py +1 -1
math_model.py CHANGED

@@ -51,7 +51,7 @@ class QuantLinear(nn.Module):
         quant_weight = quantize(self.linear.weight, self.weight_scale, self.weight_zp, is_asym=True).to(torch.uint8)
         quant_input = quantize(scaled_x, self.input_scale, self.input_zp, is_asym=False).to(torch.int8)
         quant_output = torch.nn.functional.linear(quant_input.to(torch.float32), quant_weight.to(torch.float32), None).to(torch.int32) # Convert inputs to FP32 to avoid F.linear quantizing the output to int8
-        correction = torch.sum(quant_input, dim=-1).to(torch.int32).unsqueeze(-1) * (-self.weight_zp).to(torch.uint8).view([1]*(quant_input.ndim-1) + [self.weight_zp.nelement()]) # Correct for weight zero-point
+        correction = torch.sum(quant_input, dim=-1, keepdim=True).to(torch.int32) * (-self.weight_zp).to(torch.uint8).view([1]*(quant_input.ndim-1) + [self.weight_zp.nelement()]) # Correct for weight zero-point
         quant_output = quant_output + correction
         output = dequantize(quant_output, (self.weight_scale * self.input_scale).view([1]*(quant_output.ndim-1) + [(self.weight_scale * self.input_scale).nelement()]), 0.0)
         output += self.linear.bias
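
For reference, the `keepdim=True` form is behaviorally identical to the removed `.unsqueeze(-1)` form; a minimal sketch (not part of the commit) demonstrating the equivalence:

import torch

# Summing with keepdim=True matches summing and then unsqueezing the
# reduced dimension, so the change to line 54 preserves behavior.
x = torch.randint(-128, 128, (2, 3, 4), dtype=torch.int8)
old = torch.sum(x, dim=-1).to(torch.int32).unsqueeze(-1)   # pre-commit form
new = torch.sum(x, dim=-1, keepdim=True).to(torch.int32)   # post-commit form
assert torch.equal(old, new) and old.shape == (2, 3, 1)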
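
For context on the corrected line: with asymmetric weights, `quantize` produces codes satisfying W ≈ weight_scale * (Qw − weight_zp), so the integer accumulator from `F.linear` overshoots the true product by weight_zp * sum(Qx) in each output channel, which is exactly what `correction` removes before dequantization by `weight_scale * input_scale`. A self-contained sketch with made-up shapes and int64 codes (assumptions; the module itself uses int8/uint8 codes and an int32 accumulator):

import torch

# Hypothetical toy shapes: 4 rows, 16 input features, 8 output channels.
torch.manual_seed(0)
Qx = torch.randint(-128, 128, (4, 16))  # symmetric int8-range input codes
Qw = torch.randint(0, 256, (8, 16))     # asymmetric uint8-range weight codes
w_zp = torch.randint(0, 256, (8,))      # per-output-channel zero-points

raw = Qx @ Qw.T  # what F.linear accumulates on the quantized codes
correction = torch.sum(Qx, dim=-1, keepdim=True) * (-w_zp).view(1, -1)
reference = Qx @ (Qw - w_zp).T  # zero-point folded out before the matmul
assert torch.equal(raw + correction, reference)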