File size: 1,071 Bytes
1dc29e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
#pragma once

#include <vector>

#include <torch/torch.h>

/// @brief Quantization entry point; only the prototype is declared here, the
///        definition is not part of this header.
///
/// NOTE(review): going by the name, this performs symmetric quantization
/// along the last axis of `weight` -- confirm against the definition.
///
/// @param weight      Tensor to quantize. Taken by const reference, so the
///                    argument cannot be rebound through this parameter.
/// @param quant_type  Scalar type selecting the quantized representation.
///                    TODO confirm which `at::ScalarType` values the
///                    implementation accepts.
/// @param return_unprocessed_quantized_tensor
///                    Boolean switch; judging by its name it asks for the
///                    quantized tensor prior to any layout preprocessing to
///                    be part of the result -- verify against the definition.
/// @return A vector of tensors. The number and ordering of the entries cannot
///         be determined from this header.
std::vector<torch::Tensor>
symmetric_quantize_last_axis_of_tensor(torch::Tensor const &weight,
                                       at::ScalarType quant_type,
                                       bool return_unprocessed_quantized_tensor);

/// @brief Weight preprocessing entry point; only the prototype is declared
///        here, the definition is not part of this header.
///
/// NOTE(review): the `_cuda` suffix suggests the result is prepared for the
/// CUDA GEMM kernels declared in this header -- confirm against the
/// definition, including where the input tensor is expected to reside.
///
/// @param ori_weight  Original weight tensor, taken by const reference.
/// @param is_int4     Boolean switch; presumably selects 4-bit handling over
///                    the default (likely 8-bit) -- TODO confirm.
/// @return A tensor holding the preprocessed weights; its shape, dtype and
///         device are not visible from this header.
torch::Tensor preprocess_weights_cuda(torch::Tensor const &ori_weight,
                                      bool is_int4);

/// @brief GEMM forward entry point that returns its result as a new tensor
///        value; only the prototype is declared here, the definition is not
///        part of this header.
///
/// NOTE(review): the `w8_a16` prefix presumably stands for 8-bit weights with
/// 16-bit activations -- confirm against the definition.
///
/// @param input   Input tensor, taken by const reference.
/// @param weight  Weight tensor, taken by const reference. Presumably already
///                quantized and preprocessed -- TODO confirm the expected
///                layout and dtype.
/// @param scale   Scale tensor, taken by const reference. Presumably the
///                quantization scales belonging to `weight` -- TODO confirm.
/// @return The result tensor; its shape, dtype and device are not visible
///         from this header.
torch::Tensor w8_a16_gemm_forward_cuda(torch::Tensor const &input,
                                       torch::Tensor const &weight,
                                       torch::Tensor const &scale);

/// @brief GEMM forward entry point that additionally receives a
///        caller-supplied `output` tensor and explicit problem dimensions;
///        only the prototype is declared here, the definition is not part of
///        this header.
///
/// NOTE(review): the trailing underscore presumably follows the PyTorch
/// convention for in-place / out-variant operations, i.e. the result is
/// written into `output` -- confirm against the definition.
///
/// @param input   Input tensor, taken by const reference.
/// @param weight  Weight tensor, taken by const reference. Presumably already
///                quantized and preprocessed -- TODO confirm.
/// @param scale   Scale tensor, taken by const reference. Presumably the
///                quantization scales belonging to `weight` -- TODO confirm.
/// @param output  The only parameter passed by non-const reference, so the
///                callee is able to modify it.
/// @param m       GEMM dimension, 64-bit signed. Presumably the row count of
///                the activation matrix -- TODO confirm.
/// @param n       GEMM dimension, 64-bit signed. Presumably the column count
///                of the result -- TODO confirm.
/// @param k       GEMM dimension, 64-bit signed. Presumably the shared
///                (reduction) dimension -- TODO confirm.
/// @return A tensor; whether it is the same tensor as `output` is not visible
///         from this header.
torch::Tensor w8_a16_gemm_forward_cuda_(torch::Tensor const &input,
                                        torch::Tensor const &weight,
                                        torch::Tensor const &scale,
                                        torch::Tensor &output,
                                        const int64_t m,
                                        const int64_t n,
                                        const int64_t k);