|
#include <torch/library.h> |
|
|
|
#include "registration.h" |
|
|
|
#include "torch_binding.h" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Registers this extension's custom ops with PyTorch under
// TORCH_EXTENSION_NAME. Each op gets a TorchScript-style schema string
// (ops.def) and a binding to its C++/CUDA implementation (ops.impl).
// In the schemas, `Tensor!` / `Tensor(x!)` marks an argument the kernel
// writes in place; `Tensor?` marks an optional argument.
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // PagedAttention V1: attention over a block-paged KV cache, writing the
  // result into `out`. The blocksparse_* ints parameterize the optional
  // block-sparse attention pattern.
  ops.def(
      "paged_attention_v1("
      "    Tensor! out, Tensor query, Tensor key_cache,"
      "    Tensor value_cache, int num_kv_heads, float scale,"
      "    Tensor block_tables, Tensor seq_lens, int block_size,"
      "    int max_seq_len, Tensor? alibi_slopes,"
      "    str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
      "    int tp_rank, int blocksparse_local_blocks,"
      "    int blocksparse_vert_stride, int blocksparse_block_size,"
      "    int blocksparse_head_sliding_step) -> ()");
  ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1);

  // PagedAttention V2: same computation as V1 but with additional scratch
  // outputs (exp_sums, max_logits, tmp_out) for the two-phase reduction.
  ops.def(
      "paged_attention_v2("
      "    Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
      "    Tensor! tmp_out, Tensor query, Tensor key_cache,"
      "    Tensor value_cache, int num_kv_heads, float scale,"
      "    Tensor block_tables, Tensor seq_lens, int block_size,"
      "    int max_seq_len, Tensor? alibi_slopes,"
      "    str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
      "    int tp_rank, int blocksparse_local_blocks,"
      "    int blocksparse_vert_stride, int blocksparse_block_size,"
      "    int blocksparse_head_sliding_step) -> ()");
  ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);

  // Swap cache blocks between src and dst according to block_mapping
  // (e.g. CPU<->GPU block migration).
  ops.def(
      "swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()");
  ops.impl("swap_blocks", torch::kCUDA, &swap_blocks);

  // Copy cache blocks within the given key/value cache lists per
  // block_mapping. Both lists are mutated element-wise, so both use the
  // element-level alias form `Tensor(x!)[]` (the original mixed
  // `Tensor(a!)[]` with the container-level `Tensor[](b!)`).
  ops.def(
      "copy_blocks(Tensor(a!)[] key_caches, Tensor(b!)[] value_caches, "
      "Tensor block_mapping) -> ()");
  ops.impl("copy_blocks", torch::kCUDA, &copy_blocks);

  // Scatter new key/value tokens into the paged caches at slot_mapping,
  // quantizing per kv_cache_dtype with k_scale/v_scale.
  ops.def(
      "reshape_and_cache(Tensor key, Tensor value,"
      "                  Tensor! key_cache, Tensor! value_cache,"
      "                  Tensor slot_mapping,"
      "                  str kv_cache_dtype,"
      "                  Tensor k_scale, Tensor v_scale) -> ()");
  ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache);

  // Variant of reshape_and_cache for the flash-attention cache layout.
  ops.def(
      "reshape_and_cache_flash(Tensor key, Tensor value,"
      "                        Tensor! key_cache,"
      "                        Tensor! value_cache,"
      "                        Tensor slot_mapping,"
      "                        str kv_cache_dtype,"
      "                        Tensor k_scale, Tensor v_scale) -> ()");
  ops.impl("reshape_and_cache_flash", torch::kCUDA,
           &reshape_and_cache_flash);

  // Device-property queries run on the host, so they are registered
  // without a dispatch key (CatchAll) rather than under torch::kCUDA.
  ops.def("get_device_attribute(int attribute, int device_id) -> int");
  ops.impl("get_device_attribute", &get_device_attribute);

  ops.def(
      "get_max_shared_memory_per_block_device_attribute(int device_id) -> int");
  ops.impl("get_max_shared_memory_per_block_device_attribute",
           &get_max_shared_memory_per_block_device_attribute);

  // Convert cache contents between fp8 and the source dtype, scaling by
  // `scale`, into dst_cache.
  ops.def(
      "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
      "str kv_cache_dtype) -> ()");
  ops.impl("convert_fp8", torch::kCUDA, &convert_fp8);
}
|
|
|
// NOTE(review): macro from registration.h — presumably expands to the
// Python module init boilerplate (PyInit_<TORCH_EXTENSION_NAME>) so the
// compiled library is importable; confirm against registration.h.
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
|