File size: 1,071 Bytes
1dc29e9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
#pragma once
#include <vector>
#include <torch/torch.h>
std::vector<torch::Tensor>
symmetric_quantize_last_axis_of_tensor(torch::Tensor const &weight,
at::ScalarType quant_type,
bool return_unprocessed_quantized_tensor);
/// Weight preprocessing entry point (the `_cuda` suffix suggests a CUDA-side
/// implementation — NOTE(review): confirm where the definition lives).
///
/// @param ori_weight  Input weight tensor, taken by const reference. Its
///                    expected dtype/layout is not visible here — TODO confirm.
/// @param is_int4  Flag whose name suggests it selects 4-bit rather than wider
///                 packing of the weights — verify in the implementation.
/// @return The preprocessed weight tensor.
torch::Tensor preprocess_weights_cuda(
    const torch::Tensor &ori_weight,
    bool is_int4);
/// GEMM forward entry point; the name suggests 8-bit weights with 16-bit
/// activations (NOTE(review): confirm dtypes against the definition).
///
/// @param input   Activation tensor, taken by const reference.
/// @param weight  Weight tensor, taken by const reference (presumably the
///                output of the quantize/preprocess helpers in this header —
///                verify).
/// @param scale   Scale tensor, taken by const reference; its shape relative
///                to `weight` is not visible here — TODO confirm.
/// @return The result tensor produced by the implementation.
torch::Tensor w8_a16_gemm_forward_cuda(
    const torch::Tensor &input,
    const torch::Tensor &weight,
    const torch::Tensor &scale);
/// Variant of `w8_a16_gemm_forward_cuda` that also takes a caller-supplied
/// `output` tensor and explicit problem sizes.
///
/// @param input   Activation tensor, taken by const reference.
/// @param weight  Weight tensor, taken by const reference.
/// @param scale   Scale tensor, taken by const reference.
/// @param output  Taken by non-const reference, so the implementation may write
///                into it (presumably the GEMM result buffer — TODO confirm
///                required shape/dtype).
/// @param m, n, k  GEMM problem dimensions, by name (conventionally
///                 (m x k) * (k x n) — NOTE(review): verify against the kernel).
/// @return A tensor produced by the implementation (possibly `output` itself —
///         not determinable from this declaration).
///
/// Top-level `const` on the by-value `int64_t` parameters is intentionally
/// omitted here: it is not part of the function type, so it has no effect in a
/// declaration. The definition may still mark them `const` locally.
torch::Tensor w8_a16_gemm_forward_cuda_(torch::Tensor const &input,
                                        torch::Tensor const &weight,
                                        torch::Tensor const &scale,
                                        torch::Tensor &output,
                                        int64_t m,
                                        int64_t n,
                                        int64_t k);
|