Add activation kernels
Browse files- README.md +3 -3
- activation/activation_kernels.cu +204 -0
- activation/cuda_compat.h +49 -0
- activation/dispatch_utils.h +35 -0
- build.toml +18 -0
- ext-torch/__init__.py +32 -0
- ext-torch/registration.h +27 -0
- ext-torch/torch_binding.cpp +33 -0
- ext-torch/torch_binding.h +18 -0
README.md
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
|
|
1 |
+
## Activation
|
2 |
+
|
3 |
+
Activation kernels from [vLLM](https://github.com/vllm-project/vllm/blob/main/csrc/activation_kernels.cu).
|
activation/activation_kernels.cu
ADDED
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <ATen/cuda/CUDAContext.h>
|
2 |
+
#include <torch/all.h>
|
3 |
+
#include <c10/cuda/CUDAGuard.h>
|
4 |
+
|
5 |
+
#include <cmath>
|
6 |
+
|
7 |
+
#include "cuda_compat.h"
|
8 |
+
#include "dispatch_utils.h"
|
9 |
+
|
10 |
+
namespace vllm {
|
11 |
+
|
12 |
+
// Activation and gating kernel template.
|
13 |
+
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
|
14 |
+
__global__ void act_and_mul_kernel(
|
15 |
+
scalar_t* __restrict__ out, // [..., d]
|
16 |
+
const scalar_t* __restrict__ input, // [..., 2, d]
|
17 |
+
const int d) {
|
18 |
+
const int64_t token_idx = blockIdx.x;
|
19 |
+
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
20 |
+
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
|
21 |
+
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
|
22 |
+
out[token_idx * d + idx] = ACT_FN(x) * y;
|
23 |
+
}
|
24 |
+
}
|
25 |
+
|
26 |
+
template <typename T>
|
27 |
+
__device__ __forceinline__ T silu_kernel(const T& x) {
|
28 |
+
// x * sigmoid(x)
|
29 |
+
return (T)(((float)x) / (1.0f + expf((float)-x)));
|
30 |
+
}
|
31 |
+
|
32 |
+
template <typename T>
|
33 |
+
__device__ __forceinline__ T gelu_kernel(const T& x) {
|
34 |
+
// Equivalent to PyTorch GELU with 'none' approximation.
|
35 |
+
// Refer to:
|
36 |
+
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38
|
37 |
+
const float f = (float)x;
|
38 |
+
constexpr float ALPHA = M_SQRT1_2;
|
39 |
+
return (T)(f * 0.5f * (1.0f + ::erf(f * ALPHA)));
|
40 |
+
}
|
41 |
+
|
42 |
+
template <typename T>
|
43 |
+
__device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
|
44 |
+
// Equivalent to PyTorch GELU with 'tanh' approximation.
|
45 |
+
// Refer to:
|
46 |
+
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30
|
47 |
+
const float f = (float)x;
|
48 |
+
constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f;
|
49 |
+
constexpr float KAPPA = 0.044715;
|
50 |
+
float x_cube = f * f * f;
|
51 |
+
float inner = BETA * (f + KAPPA * x_cube);
|
52 |
+
return (T)(0.5f * f * (1.0f + ::tanhf(inner)));
|
53 |
+
}
|
54 |
+
|
55 |
+
} // namespace vllm
|
56 |
+
|
57 |
+
// Launch activation and gating kernel.
|
58 |
+
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
|
59 |
+
int d = input.size(-1) / 2; \
|
60 |
+
int64_t num_tokens = input.numel() / input.size(-1); \
|
61 |
+
dim3 grid(num_tokens); \
|
62 |
+
dim3 block(std::min(d, 1024)); \
|
63 |
+
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
|
64 |
+
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
|
65 |
+
VLLM_DISPATCH_FLOATING_TYPES( \
|
66 |
+
input.scalar_type(), "act_and_mul_kernel", [&] { \
|
67 |
+
vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>> \
|
68 |
+
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
|
69 |
+
input.data_ptr<scalar_t>(), d); \
|
70 |
+
});
|
71 |
+
|
72 |
+
void silu_and_mul(torch::Tensor& out, // [..., d]
|
73 |
+
torch::Tensor& input) // [..., 2 * d]
|
74 |
+
{
|
75 |
+
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
|
76 |
+
}
|
77 |
+
|
78 |
+
void gelu_and_mul(torch::Tensor& out, // [..., d]
|
79 |
+
torch::Tensor& input) // [..., 2 * d]
|
80 |
+
{
|
81 |
+
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel);
|
82 |
+
}
|
83 |
+
|
84 |
+
void gelu_tanh_and_mul(torch::Tensor& out, // [..., d]
|
85 |
+
torch::Tensor& input) // [..., 2 * d]
|
86 |
+
{
|
87 |
+
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel);
|
88 |
+
}
|
89 |
+
|
90 |
+
namespace vllm {
|
91 |
+
|
92 |
+
template <typename T>
|
93 |
+
__device__ __forceinline__ T fatrelu_kernel(const T& x, const float threshold) {
|
94 |
+
const float f = (float)x;
|
95 |
+
return (T)(f > threshold ? f : 0.0f);
|
96 |
+
}
|
97 |
+
|
98 |
+
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&, const float)>
|
99 |
+
__global__ void act_and_mul_kernel_with_param(
|
100 |
+
scalar_t* __restrict__ out, const scalar_t* __restrict__ input, const int d,
|
101 |
+
const float param) {
|
102 |
+
const int64_t token_idx = blockIdx.x;
|
103 |
+
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
104 |
+
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
|
105 |
+
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
|
106 |
+
out[token_idx * d + idx] = ACT_FN(x, param) * y;
|
107 |
+
}
|
108 |
+
}
|
109 |
+
|
110 |
+
} // namespace vllm
|
111 |
+
|
112 |
+
#define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(KERNEL, PARAM) \
|
113 |
+
int d = input.size(-1) / 2; \
|
114 |
+
int64_t num_tokens = input.numel() / input.size(-1); \
|
115 |
+
dim3 grid(num_tokens); \
|
116 |
+
dim3 block(std::min(d, 1024)); \
|
117 |
+
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
|
118 |
+
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
|
119 |
+
VLLM_DISPATCH_FLOATING_TYPES( \
|
120 |
+
input.scalar_type(), "act_and_mul_kernel_with_param", [&] { \
|
121 |
+
vllm::act_and_mul_kernel_with_param<scalar_t, KERNEL<scalar_t>> \
|
122 |
+
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
|
123 |
+
input.data_ptr<scalar_t>(), d, \
|
124 |
+
PARAM); \
|
125 |
+
});
|
126 |
+
|
127 |
+
void fatrelu_and_mul(torch::Tensor& out, // [..., d],
|
128 |
+
torch::Tensor& input, // [..., 2 * d]
|
129 |
+
double threshold) {
|
130 |
+
LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(vllm::fatrelu_kernel, threshold);
|
131 |
+
}
|
132 |
+
namespace vllm {
|
133 |
+
|
134 |
+
// Element-wise activation kernel template.
|
135 |
+
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
|
136 |
+
__global__ void activation_kernel(
|
137 |
+
scalar_t* __restrict__ out, // [..., d]
|
138 |
+
const scalar_t* __restrict__ input, // [..., d]
|
139 |
+
const int d) {
|
140 |
+
const int64_t token_idx = blockIdx.x;
|
141 |
+
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
142 |
+
const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]);
|
143 |
+
out[token_idx * d + idx] = ACT_FN(x);
|
144 |
+
}
|
145 |
+
}
|
146 |
+
|
147 |
+
} // namespace vllm
|
148 |
+
|
149 |
+
// Launch element-wise activation kernel.
|
150 |
+
#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
|
151 |
+
int d = input.size(-1); \
|
152 |
+
int64_t num_tokens = input.numel() / d; \
|
153 |
+
dim3 grid(num_tokens); \
|
154 |
+
dim3 block(std::min(d, 1024)); \
|
155 |
+
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
|
156 |
+
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
|
157 |
+
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] { \
|
158 |
+
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>> \
|
159 |
+
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
|
160 |
+
input.data_ptr<scalar_t>(), d); \
|
161 |
+
});
|
162 |
+
|
163 |
+
namespace vllm {
|
164 |
+
|
165 |
+
template <typename T>
|
166 |
+
__device__ __forceinline__ T gelu_new_kernel(const T& x) {
|
167 |
+
const float x3 = (float)(x * x * x);
|
168 |
+
const T t = (T)tanhf((T)(0.79788456f * (float)(x + (T)(0.044715f * x3))));
|
169 |
+
return ((T)0.5) * x * (((T)1.0) + t);
|
170 |
+
}
|
171 |
+
|
172 |
+
template <typename T>
|
173 |
+
__device__ __forceinline__ T gelu_fast_kernel(const T& x) {
|
174 |
+
const float f = (float)x;
|
175 |
+
const T t =
|
176 |
+
(T)tanhf(((T)(f * 0.79788456f)) * (((T)1.0) + (T)(0.044715f * f) * x));
|
177 |
+
return ((T)0.5) * x * (((T)1.0) + t);
|
178 |
+
}
|
179 |
+
|
180 |
+
template <typename T>
|
181 |
+
__device__ __forceinline__ T gelu_quick_kernel(const T& x) {
|
182 |
+
// x * sigmoid(1.702 * x)
|
183 |
+
return (T)(((float)x) / (1.0f + expf(-1.702f * (float)x)));
|
184 |
+
}
|
185 |
+
|
186 |
+
} // namespace vllm
|
187 |
+
|
188 |
+
void gelu_new(torch::Tensor& out, // [..., d]
|
189 |
+
torch::Tensor& input) // [..., d]
|
190 |
+
{
|
191 |
+
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
|
192 |
+
}
|
193 |
+
|
194 |
+
void gelu_fast(torch::Tensor& out, // [..., d]
|
195 |
+
torch::Tensor& input) // [..., d]
|
196 |
+
{
|
197 |
+
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
|
198 |
+
}
|
199 |
+
|
200 |
+
void gelu_quick(torch::Tensor& out, // [..., d]
|
201 |
+
torch::Tensor& input) // [..., d]
|
202 |
+
{
|
203 |
+
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_quick_kernel);
|
204 |
+
}
|
activation/cuda_compat.h
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#pragma once
|
2 |
+
|
3 |
+
#ifdef USE_ROCM
|
4 |
+
#include <hip/hip_runtime.h>
|
5 |
+
#endif
|
6 |
+
|
7 |
+
#ifndef USE_ROCM
|
8 |
+
#define WARP_SIZE 32
|
9 |
+
#else
|
10 |
+
#define WARP_SIZE warpSize
|
11 |
+
#endif
|
12 |
+
|
13 |
+
#ifndef USE_ROCM
|
14 |
+
#define VLLM_LDG(arg) __ldg(arg)
|
15 |
+
#else
|
16 |
+
#define VLLM_LDG(arg) *(arg)
|
17 |
+
#endif
|
18 |
+
|
19 |
+
#ifndef USE_ROCM
|
20 |
+
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) \
|
21 |
+
__shfl_xor_sync(uint32_t(-1), var, lane_mask)
|
22 |
+
#define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \
|
23 |
+
__shfl_xor_sync(uint32_t(-1), var, lane_mask, width)
|
24 |
+
#else
|
25 |
+
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
|
26 |
+
#define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \
|
27 |
+
__shfl_xor(var, lane_mask, width)
|
28 |
+
#endif
|
29 |
+
|
30 |
+
#ifndef USE_ROCM
|
31 |
+
#define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane)
|
32 |
+
#else
|
33 |
+
#define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
|
34 |
+
#endif
|
35 |
+
|
36 |
+
#ifndef USE_ROCM
|
37 |
+
#define VLLM_SHFL_DOWN_SYNC(var, lane_delta) \
|
38 |
+
__shfl_down_sync(uint32_t(-1), var, lane_delta)
|
39 |
+
#else
|
40 |
+
#define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta)
|
41 |
+
#endif
|
42 |
+
|
43 |
+
#ifndef USE_ROCM
|
44 |
+
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
|
45 |
+
cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
|
46 |
+
#else
|
47 |
+
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
|
48 |
+
hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
|
49 |
+
#endif
|
activation/dispatch_utils.h
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*
|
2 |
+
* Adapted from
|
3 |
+
* https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h
|
4 |
+
*/
|
5 |
+
#pragma once
|
6 |
+
|
7 |
+
#include <torch/all.h>
|
8 |
+
|
9 |
+
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
10 |
+
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
11 |
+
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
|
12 |
+
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
|
13 |
+
|
14 |
+
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
15 |
+
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
|
16 |
+
|
17 |
+
#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
|
18 |
+
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
19 |
+
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
|
20 |
+
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
|
21 |
+
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
|
22 |
+
|
23 |
+
#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
|
24 |
+
AT_DISPATCH_SWITCH(TYPE, NAME, \
|
25 |
+
VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))
|
26 |
+
|
27 |
+
#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \
|
28 |
+
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
|
29 |
+
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
|
30 |
+
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
|
31 |
+
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
|
32 |
+
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
|
33 |
+
|
34 |
+
#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
|
35 |
+
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
|
build.toml
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[torch]
|
2 |
+
name = "activation"
|
3 |
+
src = [
|
4 |
+
"ext-torch/registration.h",
|
5 |
+
"ext-torch/torch_binding.cpp",
|
6 |
+
"ext-torch/torch_binding.h"
|
7 |
+
]
|
8 |
+
pysrc = [
|
9 |
+
"ext-torch/__init__.py"
|
10 |
+
]
|
11 |
+
|
12 |
+
[kernel.activation]
|
13 |
+
capabilities = [ "7.0", "7.2", "7.5", "8.0", "8.6", "8.7", "8.9", "9.0" ]
|
14 |
+
src = [
|
15 |
+
"activation/activation_kernels.cu",
|
16 |
+
"activation/cuda_compat.h",
|
17 |
+
"activation/dispatch_utils.h",
|
18 |
+
]
|
ext-torch/__init__.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
import activation._activation
|
4 |
+
|
5 |
+
def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
|
6 |
+
torch.ops._activation.silu_and_mul(out, x)
|
7 |
+
|
8 |
+
|
9 |
+
def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
|
10 |
+
torch.ops._activation.gelu_and_mul(out, x)
|
11 |
+
|
12 |
+
|
13 |
+
def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
|
14 |
+
torch.ops._activation.gelu_tanh_and_mul(out, x)
|
15 |
+
|
16 |
+
|
17 |
+
def fatrelu_and_mul(out: torch.Tensor,
|
18 |
+
x: torch.Tensor,
|
19 |
+
threshold: float = 0.0) -> None:
|
20 |
+
torch.ops._activation.fatrelu_and_mul(out, x, threshold)
|
21 |
+
|
22 |
+
|
23 |
+
def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
|
24 |
+
torch.ops._activation.gelu_fast(out, x)
|
25 |
+
|
26 |
+
|
27 |
+
def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
|
28 |
+
torch.ops._activation.gelu_new(out, x)
|
29 |
+
|
30 |
+
|
31 |
+
def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
|
32 |
+
torch.ops._activation.gelu_quick(out, x)
|
ext-torch/registration.h
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#pragma once
|
2 |
+
|
3 |
+
#include <Python.h>
|
4 |
+
|
5 |
+
#define _CONCAT(A, B) A##B
|
6 |
+
#define CONCAT(A, B) _CONCAT(A, B)
|
7 |
+
|
8 |
+
#define _STRINGIFY(A) #A
|
9 |
+
#define STRINGIFY(A) _STRINGIFY(A)
|
10 |
+
|
11 |
+
// A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME
|
12 |
+
// could be a macro instead of a literal token.
|
13 |
+
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
|
14 |
+
|
15 |
+
// A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME
|
16 |
+
// could be a macro instead of a literal token.
|
17 |
+
#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \
|
18 |
+
TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)
|
19 |
+
|
20 |
+
// REGISTER_EXTENSION allows the shared library to be loaded and initialized
|
21 |
+
// via python's import statement.
|
22 |
+
#define REGISTER_EXTENSION(NAME) \
|
23 |
+
PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \
|
24 |
+
static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, \
|
25 |
+
STRINGIFY(NAME), nullptr, 0, nullptr}; \
|
26 |
+
return PyModule_Create(&module); \
|
27 |
+
}
|
ext-torch/torch_binding.cpp
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <torch/library.h>
|
2 |
+
|
3 |
+
#include "registration.h"
|
4 |
+
#include "torch_binding.h"
|
5 |
+
|
6 |
+
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
7 |
+
// Activation ops
|
8 |
+
// Activation function used in SwiGLU.
|
9 |
+
ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
|
10 |
+
ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
|
11 |
+
|
12 |
+
// Activation function used in GeGLU with `none` approximation.
|
13 |
+
ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
|
14 |
+
ops.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
|
15 |
+
|
16 |
+
// Activation function used in GeGLU with `tanh` approximation.
|
17 |
+
ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
|
18 |
+
ops.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);
|
19 |
+
|
20 |
+
// FATReLU implementation.
|
21 |
+
ops.def("fatrelu_and_mul(Tensor! out, Tensor input, float threshold) -> ()");
|
22 |
+
ops.impl("fatrelu_and_mul", torch::kCUDA, &fatrelu_and_mul);
|
23 |
+
|
24 |
+
// GELU implementation used in GPT-2.
|
25 |
+
ops.def("gelu_new(Tensor! out, Tensor input) -> ()");
|
26 |
+
ops.impl("gelu_new", torch::kCUDA, &gelu_new);
|
27 |
+
|
28 |
+
// Approximate GELU implementation.
|
29 |
+
ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
|
30 |
+
ops.impl("gelu_fast", torch::kCUDA, &gelu_fast);
|
31 |
+
}
|
32 |
+
|
33 |
+
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
ext-torch/torch_binding.h
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#pragma once
|
2 |
+
|
3 |
+
#include <torch/torch.h>
|
4 |
+
|
5 |
+
void silu_and_mul(torch::Tensor &out, torch::Tensor &input);
|
6 |
+
|
7 |
+
void gelu_and_mul(torch::Tensor &out, torch::Tensor &input);
|
8 |
+
|
9 |
+
void gelu_tanh_and_mul(torch::Tensor &out, torch::Tensor &input);
|
10 |
+
|
11 |
+
void fatrelu_and_mul(torch::Tensor &out, torch::Tensor &input,
|
12 |
+
double threshold);
|
13 |
+
|
14 |
+
void gelu_new(torch::Tensor &out, torch::Tensor &input);
|
15 |
+
|
16 |
+
void gelu_fast(torch::Tensor &out, torch::Tensor &input);
|
17 |
+
|
18 |
+
void gelu_quick(torch::Tensor &out, torch::Tensor &input);
|