# sdxl-quant-int8 / test_quant_linear.py
# Author: nickfraser
# Commit 673c9f2: "[test] Fixed shapes to match new `quant_param.json`"
# Size: 717 bytes
import torch
from math_model import QuantLinear
def make_quant_params(in_ch, out_ch):
    """Build a random ``quant_params`` dict for ``QuantLinear(in_ch, out_ch, ...)``.

    Each tensor entry is paired with a ``*_shape`` entry:

    * ``smoothquant_mul``: ``in_ch`` uniform values in [0, 1), shape ``(1, in_ch)``.
    * ``weight_scale``: ``out_ch`` uniform values in [0, 1), shape ``(out_ch, 1)``.
    * ``weight_zp``: ``out_ch`` integers in [-255, -1], shape ``(out_ch, 1)``.
    * ``input_scale``: one uniform value in [0, 1), scalar shape ``()``.
    * ``input_zp``: one zero, scalar shape ``()``.

    How ``math_model.QuantLinear`` consumes the ``*_shape`` entries is not
    visible here; presumably it reshapes each tensor to that shape (confirm
    against math_model).

    Values are drawn from torch's global RNG in the order listed above. Seed
    the RNG first for reproducible parameters.

    Args:
        in_ch: Number of input features of the layer.
        out_ch: Number of output features of the layer.

    Returns:
        dict mapping parameter names to tensors / shape tuples.
    """
    # Dict literals evaluate left to right, so the RNG draw order is fixed.
    return {
        'smoothquant_mul': torch.rand((in_ch,)),
        'smoothquant_mul_shape': (1, in_ch),
        'weight_scale': torch.rand((out_ch,)),
        'weight_scale_shape': (out_ch, 1),
        # randint's upper bound is exclusive, so zero-points lie in [-255, -1].
        'weight_zp': torch.randint(-255, 0, (out_ch,)),
        'weight_zp_shape': (out_ch, 1),
        'input_scale': torch.rand((1,)),
        'input_scale_shape': tuple(),
        'input_zp': torch.zeros((1,)),
        'input_zp_shape': tuple(),
    }


def main(batch_size=1, in_ch=64, out_ch=128, seed=0):
    """Compare QuantLinear's default (QDQ) and ``qop=True`` forward outputs.

    Seeds torch's global RNG and builds random quantization parameters and a
    ``QuantLinear`` layer. It runs one random input of shape
    ``(batch_size, in_ch)`` through both forward paths. It then prints the
    parameters, both output shapes, the element-wise difference
    ``qdq - qop`` and its maximum absolute value.

    Args:
        batch_size: Number of rows in the random input.
        in_ch: Number of input features of the layer.
        out_ch: Number of output features of the layer.
        seed: Seed for torch's global RNG. With the default arguments, the
            parameters and input match the original standalone script.

    Returns:
        Tuple ``(qdq_output, qop_output)``.
    """
    torch.manual_seed(seed)
    quant_params = make_quant_params(in_ch, out_ch)
    print(quant_params)

    # Build the layer before drawing the input. The constructor may itself
    # consume RNG state, and this is the order the original script used.
    layer = QuantLinear(in_ch, out_ch, quant_params)
    inp = torch.rand((batch_size, in_ch))

    # Evaluation only: no need to record an autograd graph.
    with torch.no_grad():
        o_qdq = layer(inp)
        o_qop = layer(inp, qop=True)

    diff = o_qdq - o_qop
    print(o_qdq.shape)
    print(o_qop.shape)
    print(diff)
    if diff.numel():  # .max() raises on an empty tensor (e.g. batch_size=0)
        print(f"max |qdq - qop|: {diff.abs().max().item()}")
    return o_qdq, o_qop


# Run only when executed as a script, not when imported (e.g. by pytest
# collecting this test_*.py file).
if __name__ == "__main__":
    main()