|
#include <torch/library.h> |
|
|
|
#include "registration.h" |
|
#include "torch_binding.h" |
|
|
|
// Registers this extension's activation operators with the PyTorch dispatcher.
//
// Every operator follows the same two-step pattern: its schema is declared
// with `ops.def`, then a kernel (declared in torch_binding.h) is bound to it
// for the CUDA dispatch key only. No CPU or other backend kernel is
// registered in this block.
//
// All schemas take a mutable `Tensor! out` as their first argument and return
// `()`. Results are therefore written into `out` in place rather than returned.
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // Declares `schema` on the library, then binds `kernel` as its CUDA
  // implementation under `name`. `name` must match the operator name that
  // begins `schema`. Registration happens in call order: def first, then impl.
  auto def_cuda_op = [&ops](const char* schema, const char* name,
                            auto kernel) {
    ops.def(schema);
    ops.impl(name, torch::kCUDA, kernel);
  };

  // "<activation>_and_mul" operators. NOTE(review): the names suggest a
  // gated activation (activation of one part of `input` multiplied by
  // another part). Confirm against the CUDA kernels before relying on it.
  def_cuda_op("silu_and_mul(Tensor! out, Tensor input) -> ()",
              "silu_and_mul", &silu_and_mul);
  def_cuda_op("gelu_and_mul(Tensor! out, Tensor input) -> ()",
              "gelu_and_mul", &gelu_and_mul);
  def_cuda_op("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()",
              "gelu_tanh_and_mul", &gelu_tanh_and_mul);

  // The only operator with an extra scalar argument. A schema `float` maps
  // to a C++ `double` parameter on the kernel side.
  def_cuda_op(
      "fatrelu_and_mul(Tensor! out, Tensor input, float threshold) -> ()",
      "fatrelu_and_mul", &fatrelu_and_mul);

  // Operators without the "_and_mul" suffix. NOTE(review): presumably plain
  // element-wise GELU approximations. Verify in the kernel sources.
  def_cuda_op("gelu_new(Tensor! out, Tensor input) -> ()",
              "gelu_new", &gelu_new);
  def_cuda_op("gelu_fast(Tensor! out, Tensor input) -> ()",
              "gelu_fast", &gelu_fast);
}
|
|
|
// Extension-module entry point for TORCH_EXTENSION_NAME.
// NOTE(review): REGISTER_EXTENSION is defined in registration.h, which is not
// visible here. Presumably it expands to the Python module init function
// (PyInit_<name>), so that importing the built extension loads the operators
// registered by the TORCH_LIBRARY_EXPAND block in this file. Confirm in
// registration.h. The macro is used without a trailing semicolon; keep it that
// way unless the macro definition says otherwise.
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
|