GiusFra committed on
Commit
d5dfd96
1 Parent(s): 88730c2

Upload math_model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. math_model.py +62 -0
math_model.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch
3
+
4
def quantize(tensor, scale, zero_point, is_asym=False):
    """Project ``tensor`` onto an 8-bit affine quantization grid.

    Computes ``clamp(round(tensor / scale) + zero_point, qmin, qmax)``. No
    integer cast is performed; the result is the rounded, range-limited
    value that ``dequantize`` expects as input.

    Args:
        tensor: Values to quantize.
        scale: Quantization step size, broadcast against ``tensor``.
        zero_point: Offset added on the integer grid, broadcast against
            ``tensor``.
        is_asym: If true, saturate to the unsigned range [0, 255];
            otherwise saturate to the signed range [-128, 127].

    Returns:
        The quantized tensor, guaranteed to lie inside the selected range.
    """
    qmin, qmax = (0., 255.) if is_asym else (-128., 127.)
    # The zero point has to be applied BEFORE saturating. Clamping first and
    # shifting afterwards lets the result leave [qmin, qmax] and, once
    # dequantize() subtracts the zero point again, makes the offset a no-op
    # (e.g. every negative input collapses to 0 in the asymmetric case).
    shifted = torch.round(tensor / scale) + zero_point
    # Plain Python floats as bounds avoid creating CPU tensors on every call.
    return torch.clamp(shifted, qmin, qmax)
11
+
12
def dequantize(tensor, scale, zero_point):
    """Map quantized values back to the real-valued domain.

    Args:
        tensor: Quantized values.
        scale: Quantization step size.
        zero_point: Offset that is removed before rescaling.

    Returns:
        ``(tensor - zero_point) * scale``.
    """
    centered = tensor - zero_point
    return centered * scale
14
+
15
+
16
class QuantLinear(nn.Module):
    """Linear layer evaluated through a quantize/dequantize round trip.

    The input is multiplied by ``mul_factor``, then both the scaled input and
    the layer weight are passed through ``quantize`` followed by
    ``dequantize`` before the floating-point linear operation is applied.

    Args:
        quant_param: Mapping that provides, for each of the names
            ``smoothquant_mul``, ``weight_scale``, ``weight_zp``,
            ``input_scale`` and ``input_zp``, the values under ``<name>`` and
            the target shape under ``<name>_shape``.
        in_features: Input size of the wrapped ``nn.Linear`` (default 128,
            the previously hard-coded value).
        out_features: Output size of the wrapped ``nn.Linear`` (default 128,
            the previously hard-coded value).
    """

    def __init__(self, quant_param, in_features=128, out_features=128):
        super().__init__()
        self._register_quant_buffer('mul_factor', quant_param, 'smoothquant_mul')
        self.linear = nn.Linear(in_features, out_features)
        for name in ('weight_scale', 'weight_zp', 'input_scale', 'input_zp'):
            self._register_quant_buffer(name, quant_param, name)

    def _register_quant_buffer(self, name, quant_param, key):
        """Load ``quant_param[key]``, reshape it and register it as a buffer."""
        values = torch.tensor(quant_param[key]).view(quant_param[key + '_shape'])
        self.register_buffer(name, values)

    def forward(self, x):
        """Apply the quantization-simulated linear layer to ``x``."""
        scaled_x = x * self.mul_factor
        # Weights use the unsigned range, activations the signed range.
        quant_weight = quantize(self.linear.weight, self.weight_scale, self.weight_zp, is_asym=True)
        quant_input = quantize(scaled_x, self.input_scale, self.input_zp, is_asym=False)
        dequantized_weight = dequantize(quant_weight, self.weight_scale, self.weight_zp)
        dequantized_input = dequantize(quant_input, self.input_scale, self.input_zp)
        # The bias is applied unquantized.
        out = torch.nn.functional.linear(dequantized_input, dequantized_weight, self.linear.bias)
        return out
39
+
40
class QuantConv2d(nn.Module):
    """2-D convolution evaluated through a quantize/dequantize round trip.

    The input is multiplied by ``mul_factor``, then both the scaled input and
    the convolution weight are passed through ``quantize`` followed by
    ``dequantize`` before the floating-point convolution is applied.

    Args:
        quant_param: Mapping that provides, for each of the names
            ``smoothquant_mul``, ``weight_scale``, ``weight_zp``,
            ``input_scale`` and ``input_zp``, the values under ``<name>`` and
            the target shape under ``<name>_shape``.
        in_channels: Input channels of the wrapped ``nn.Conv2d`` (default
            128, the previously hard-coded value).
        out_channels: Output channels of the wrapped ``nn.Conv2d`` (default
            128, the previously hard-coded value).
        kernel_size: Kernel size of the wrapped ``nn.Conv2d`` (default 3,
            the previously hard-coded value).
    """

    def __init__(self, quant_param, in_channels=128, out_channels=128, kernel_size=3):
        super().__init__()
        self._register_quant_buffer('mul_factor', quant_param, 'smoothquant_mul')
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size)
        for name in ('weight_scale', 'weight_zp', 'input_scale', 'input_zp'):
            self._register_quant_buffer(name, quant_param, name)

    def _register_quant_buffer(self, name, quant_param, key):
        """Load ``quant_param[key]``, reshape it and register it as a buffer."""
        values = torch.tensor(quant_param[key]).view(quant_param[key + '_shape'])
        self.register_buffer(name, values)

    def forward(self, x):
        """Apply the quantization-simulated convolution to ``x``."""
        scaled_x = x * self.mul_factor
        # Read the weight from ``self.conv2d``: this module has no ``linear``
        # attribute, so the former ``self.linear.weight`` raised AttributeError.
        quant_weight = quantize(self.conv2d.weight, self.weight_scale, self.weight_zp, is_asym=True)
        quant_input = quantize(scaled_x, self.input_scale, self.input_zp, is_asym=False)
        dequantized_weight = dequantize(quant_weight, self.weight_scale, self.weight_zp)
        dequantized_input = dequantize(quant_input, self.input_scale, self.input_zp)
        # Forward the wrapped module's hyper-parameters so the functional call
        # stays in sync with ``self.conv2d``; the bias is applied unquantized.
        out = torch.nn.functional.conv2d(
            dequantized_input,
            dequantized_weight,
            self.conv2d.bias,
            stride=self.conv2d.stride,
            padding=self.conv2d.padding,
            dilation=self.conv2d.dilation,
            groups=self.conv2d.groups,
        )
        return out