// quantization-eetq / cutlass_kernels / fpA_intB_gemm_wrapper.h
// Imported EETQ kernels (author: danieldk, commit 1dc29e9).
#include <torch/all.h>
#include <vector>
// Threshold on the GEMM M dimension below which a small-batch fast path is
// selected. NOTE(review): meaning inferred from the name — confirm against
// its use in the fpA_intB kernel implementation.
#define SMALL_M_FAST_PATH 4
/// Symmetrically quantizes `weight` along its last axis to `quant_type`
/// (per the function name — implementation lives in the .cu/.cc file).
///
/// @param weight      Weight tensor to quantize (read-only).
/// @param quant_type  Target quantized scalar type (e.g. int8).
/// @param return_unprocessed_quantized_tensor
///                    When true, presumably also returns the raw quantized
///                    tensor before kernel-layout preprocessing — TODO confirm.
/// @return A vector of tensors; exact contents (quantized weights, scales,
///         optional raw tensor) are not visible here — verify in the impl.
std::vector<torch::Tensor>
symmetric_quantize_last_axis_of_tensor(torch::Tensor const &weight,
at::ScalarType quant_type,
bool return_unprocessed_quantized_tensor);
/// Rearranges an already-quantized weight tensor into the layout expected by
/// the fpA_intB GEMM kernels. NOTE(review): behavior inferred from the name;
/// confirm layout details in the CUDA implementation.
///
/// @param ori_weight  Original (quantized) weight tensor.
/// @param is_int4     Presumably selects int4 vs. int8 packing — TODO confirm.
/// @return The preprocessed weight tensor.
torch::Tensor preprocess_weights_cuda(torch::Tensor const &ori_weight,
bool is_int4);
/// Weight-int8 / activation-16-bit GEMM forward pass (allocating variant).
/// Per the name: computes input × dequantized(weight, scale) and returns a
/// newly allocated output tensor — implementation not visible here.
///
/// @param input   Activation tensor (16-bit float, per the "a16" naming —
///                TODO confirm dtype expectations in the impl).
/// @param weight  Quantized int8 weight tensor (expected in preprocessed layout).
/// @param scale   Dequantization scale tensor.
/// @return The GEMM result tensor.
torch::Tensor w8_a16_gemm_forward_cuda(torch::Tensor const &input,
torch::Tensor const &weight,
torch::Tensor const &scale);
/// Variant of w8_a16_gemm_forward_cuda that writes into a caller-provided
/// `output` tensor with explicit GEMM dimensions (trailing-underscore naming
/// suggests the in/out-parameter form; confirm whether `output` is also the
/// returned tensor in the implementation).
///
/// @param input   Activation tensor.
/// @param weight  Quantized int8 weight tensor (preprocessed layout).
/// @param scale   Dequantization scale tensor.
/// @param output  Pre-allocated destination tensor (mutated).
/// @param m       GEMM M dimension (rows of `input`) — presumed; TODO confirm.
/// @param n       GEMM N dimension (output columns) — presumed; TODO confirm.
/// @param k       GEMM K dimension (reduction axis) — presumed; TODO confirm.
/// @return The result tensor (presumably `output`).
torch::Tensor w8_a16_gemm_forward_cuda_(torch::Tensor const &input,
torch::Tensor const &weight,
torch::Tensor const &scale,
torch::Tensor &output,
const int64_t m,
const int64_t n,
const int64_t k);