File size: 1,039 Bytes
eb5a5f6
4024f9d
eb5a5f6
 
 
 
 
 
 
 
4024f9d
 
 
eb5a5f6
 
673c9f2
4024f9d
eb5a5f6
6b62ce4
eb5a5f6
6b62ce4
4024f9d
673c9f2
eb5a5f6
673c9f2
6b62ce4
eb5a5f6
 
 
 
4024f9d
 
 
 
eb5a5f6
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
import torch.nn as nn
from math_model import QuantLinear

# Sanity check for math_model.QuantLinear: feed the same input through the
# default forward path and the `qop=True` path, then report how far apart the
# two outputs are. (Going by the names, the default path is presumably
# quantize-dequantize emulation and `qop=True` the quantized-op path —
# confirm against math_model.)

# Seed before any RNG use (input sampling and nn.Linear init) so the run is
# reproducible.
torch.manual_seed(0)

batch_size = 1
out_ch = 128
in_ch = 64

# Uniform input in [-1, 1) and a float reference layer whose parameters are
# copied into the quantized layer below.
x = 2 * torch.rand((batch_size, in_ch)) - 1.
ref_linear = nn.Linear(in_ch, out_ch, bias=True)

# Detach so the derived quantization parameters carry no autograd history.
weight = ref_linear.weight.detach()  # (out_ch, in_ch)
# Per-output-channel max |w|; shared by the weight scale and the zero point
# instead of being recomputed for each.
weight_absmax = torch.max(torch.abs(weight), dim=1).values  # (out_ch,)

quant_params = {
    'smoothquant_mul': torch.rand((in_ch,)),
    'smoothquant_mul_shape': (1, in_ch),
    # Per-output-channel scale that maps max |w| to 128.
    'weight_scale': weight_absmax / 128.,
    'weight_scale_shape': (out_ch, 1),
    # Zero point = round(row mean / scale), clamped to the int8 range. Stored
    # as a float tensor; presumably QuantLinear casts it per 'weight_zp_dtype'.
    'weight_zp': torch.clamp(
        torch.round(torch.mean(weight, dim=1) * (128 / weight_absmax)),
        -128,
        127,
    ),
    'weight_zp_shape': (out_ch, 1),
    'weight_zp_dtype': 'torch.int8',
    # Per-tensor input scale that maps max |x| to 128.
    'input_scale': torch.max(torch.abs(x)) / 128.,
    'input_scale_shape': tuple(),
    # NOTE(review): this tensor has shape (1,) while 'input_zp_shape' is () —
    # presumably QuantLinear reshapes it to the declared shape; confirm.
    'input_zp': torch.zeros((1,)),
    'input_zp_shape': tuple(),
    'input_zp_dtype': 'torch.int8',
}

print(quant_params)

ql = QuantLinear(in_ch, out_ch, quant_params)
# Give the quantized layer the same float weights/bias as the reference layer.
ql.linear.load_state_dict(ref_linear.state_dict())

# Inference-only comparison: no autograd graph is needed.
with torch.no_grad():
    o_qdq = ql(x)
    o_qop = ql(x, qop=True)

print(o_qdq.shape)
print(o_qop.shape)
diff = o_qdq - o_qop
print(diff)
# One-number summary of the discrepancy; easier to read than the full tensor.
print(f"max abs diff: {diff.abs().max().item():.6g}")