Commit bbf5511 by danieldk (initial commit, 0 parents)

Add rotary kernel
README.md ADDED
@@ -0,0 +1,7 @@
+ ---
+ license: bsd-3-clause
+ ---
+
+ ## rotary
+
+ rotary embedding kernel from [Flash Attention](https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary).
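For context, a minimal usage sketch of the Python wrapper added in this commit (`torch-ext/rotary/__init__.py`). It assumes the built extension is importable as `rotary` and that a CUDA device is available; the shapes and the cos/sin construction loosely mirror Flash Attention's non-interleaved rotary layout and are purely illustrative.

```python
import torch
import rotary  # assumes the built extension is importable under this name

batch, seqlen, nheads, headdim = 2, 16, 4, 64

q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
x1, x2 = q[..., : headdim // 2], q[..., headdim // 2 :]  # non-interleaved halves

# cos/sin must broadcast against x1/x2 and share their dtype: (seqlen, 1, headdim // 2).
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, headdim, 2, device="cuda") / headdim))
freqs = torch.outer(torch.arange(seqlen, device="cuda"), inv_freq)
cos = freqs.cos()[:, None, :].to(q.dtype)
sin = freqs.sin()[:, None, :].to(q.dtype)

# Out-of-place: results land in freshly allocated output tensors.
out1, out2 = torch.empty_like(x1), torch.empty_like(x2)
rotary.apply_rotary(x1, x2, cos, sin, out1, out2, conj=False)
q_rotated = torch.cat([out1, out2], dim=-1)
```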
build.toml ADDED
@@ -0,0 +1,17 @@
+ [general]
+ version = "0.0.1"
+
+ [torch]
+ name = "rotary"
+ src = [
+   "torch-ext/registration.h",
+   "torch-ext/torch_binding.cpp",
+ ]
+ pyroot = "torch-ext"
+
+ [kernel.rotary]
+ capabilities = [ "7.0", "7.2", "7.5", "8.0", "8.6", "8.7", "8.9", "9.0" ]
+ src = [
+   "rotary/rotary_cuda.cu",
+ ]
+ depends = [ "torch" ]
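The `capabilities` list pins the CUDA compute capabilities the kernel is built for. A small, hypothetical check of whether the local GPU appears in that list (the set below simply restates the build.toml values):

```python
import torch

# Compute capabilities from build.toml, as (major, minor) pairs.
SUPPORTED = {(7, 0), (7, 2), (7, 5), (8, 0), (8, 6), (8, 7), (8, 9), (9, 0)}

if torch.cuda.is_available():
    cap = torch.cuda.get_device_capability()
    status = "covered by" if cap in SUPPORTED else "not in"
    print(f"GPU compute capability {cap} is {status} the prebuilt capability list")
```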
rotary/rotary_cuda.cu ADDED
@@ -0,0 +1,45 @@
+ /******************************************************************************
+  * Copyright (c) 2023, Tri Dao.
+  ******************************************************************************/
+
+ #include <torch/all.h>
+ #include <ATen/native/TensorIterator.h>
+ #include <ATen/native/cuda/Loops.cuh>
+
+ void apply_rotary_cuda(torch::Tensor const &x1, torch::Tensor const &x2,
+                        torch::Tensor const &cos, torch::Tensor const &sin,
+                        torch::Tensor &out1, torch::Tensor &out2,
+                        bool const conj) {
+   auto iter = at::TensorIteratorConfig()
+                   .add_output(out1)
+                   .add_output(out2)
+                   .add_input(x1)
+                   .add_input(x2)
+                   .add_input(cos)
+                   .add_input(sin)
+                   .check_all_same_dtype(false)
+                   .promote_inputs_to_common_dtype(false)
+                   .build();
+
+   if (!conj) {
+     AT_DISPATCH_FLOATING_TYPES_AND2(at::kBFloat16, at::kHalf, x1.scalar_type(), "rotary_kernel", [&] {
+       at::native::gpu_kernel_multiple_outputs(
+           iter, [] GPU_LAMBDA (scalar_t x1, scalar_t x2, scalar_t cos,
+                                scalar_t sin) -> thrust::tuple<scalar_t, scalar_t> {
+             scalar_t out1 = float(x1) * float(cos) - float(x2) * float(sin);
+             scalar_t out2 = float(x1) * float(sin) + float(x2) * float(cos);
+             return {out1, out2};
+           });
+     });
+   } else {
+     AT_DISPATCH_FLOATING_TYPES_AND2(at::kBFloat16, at::kHalf, x1.scalar_type(), "rotary_kernel", [&] {
+       at::native::gpu_kernel_multiple_outputs(
+           iter, [] GPU_LAMBDA (scalar_t x1, scalar_t x2, scalar_t cos,
+                                scalar_t sin) -> thrust::tuple<scalar_t, scalar_t> {
+             scalar_t out1 = float(x1) * float(cos) + float(x2) * float(sin);
+             scalar_t out2 = -float(x1) * float(sin) + float(x2) * float(cos);
+             return {out1, out2};
+           });
+     });
+   }
+ }
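The lambda applies a plain 2D rotation elementwise, accumulating in float and casting back to the input dtype; `conj` selects the conjugate (inverse) rotation, typically used for the backward pass. A hypothetical pure-PyTorch reference, handy as a sanity check against the CUDA path:

```python
import torch

def apply_rotary_reference(x1, x2, cos, sin, conj=False):
    """Out-of-place reference for the rotation computed by rotary_cuda.cu."""
    x1f, x2f, cosf, sinf = x1.float(), x2.float(), cos.float(), sin.float()
    if not conj:
        out1 = x1f * cosf - x2f * sinf
        out2 = x1f * sinf + x2f * cosf
    else:  # conjugate / inverse rotation
        out1 = x1f * cosf + x2f * sinf
        out2 = -x1f * sinf + x2f * cosf
    return out1.to(x1.dtype), out2.to(x1.dtype)
```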
torch-ext/registration.h ADDED
@@ -0,0 +1,27 @@
+ #pragma once
+
+ #include <Python.h>
+
+ #define _CONCAT(A, B) A##B
+ #define CONCAT(A, B) _CONCAT(A, B)
+
+ #define _STRINGIFY(A) #A
+ #define STRINGIFY(A) _STRINGIFY(A)
+
+ // A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME
+ // could be a macro instead of a literal token.
+ #define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
+
+ // A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME
+ // could be a macro instead of a literal token.
+ #define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \
+   TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)
+
+ // REGISTER_EXTENSION allows the shared library to be loaded and initialized
+ // via python's import statement.
+ #define REGISTER_EXTENSION(NAME)                                               \
+   PyMODINIT_FUNC CONCAT(PyInit_, NAME)() {                                     \
+     static struct PyModuleDef module = {PyModuleDef_HEAD_INIT,                 \
+                                         STRINGIFY(NAME), nullptr, 0, nullptr}; \
+     return PyModule_Create(&module);                                           \
+   }
torch-ext/rotary/__init__.py ADDED
@@ -0,0 +1,19 @@
+ import torch
+
+ from ._ops import ops
+
+
+ def apply_rotary(
+     x1: torch.Tensor,
+     x2: torch.Tensor,
+     cos: torch.Tensor,
+     sin: torch.Tensor,
+     out1: torch.Tensor,
+     out2: torch.Tensor,
+     conj: bool,
+ ) -> None:
+     """Rotate (x1, x2) by (cos, sin) into out1/out2; conj applies the inverse rotation."""
+     ops.apply_rotary(x1, x2, cos, sin, out1, out2, conj)
+
+
+ __all__ = ["apply_rotary"]
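Because the wrapper only writes into `out1`/`out2`, callers can also pass the input tensors themselves as the outputs to rotate in place, the pattern the upstream Flash Attention Python code relies on for its in-place path. A short sketch under the same assumptions as the README example (importable `rotary` package, placeholder cos/sin):

```python
import torch
import rotary  # assumes the built extension is importable under this name

q = torch.randn(2, 16, 4, 64, device="cuda", dtype=torch.float16)
x1, x2 = q.chunk(2, dim=-1)

# Placeholder angles; real code would derive cos/sin from position indices.
angles = torch.rand(16, 1, 32, device="cuda")
cos, sin = angles.cos().half(), angles.sin().half()

# In place: the outputs alias the inputs, so q itself ends up rotated.
rotary.apply_rotary(x1, x2, cos, sin, x1, x2, conj=False)
```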
torch-ext/torch_binding.cpp ADDED
@@ -0,0 +1,42 @@
+ #include <torch/all.h>
+ #include <c10/cuda/CUDAGuard.h>
+
+ #include "registration.h"
+
+ #define CHECK_DEVICE(x) TORCH_CHECK(x.device().type() == torch::kCUDA, #x " must be on CUDA")
+ #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
+
+ void apply_rotary_cuda(torch::Tensor const &x1, torch::Tensor const &x2,
+                        torch::Tensor const &cos, torch::Tensor const &sin,
+                        torch::Tensor &out1, torch::Tensor &out2,
+                        bool const conj);
+
+ void apply_rotary(torch::Tensor const &x1, torch::Tensor const &x2,
+                   torch::Tensor const &cos, torch::Tensor const &sin,
+                   torch::Tensor &out1, torch::Tensor &out2,
+                   bool const conj) {
+   CHECK_DEVICE(x1); CHECK_DEVICE(x2);
+   CHECK_DEVICE(cos); CHECK_DEVICE(sin);
+   CHECK_DEVICE(out1); CHECK_DEVICE(out2);
+   TORCH_CHECK(x1.dtype() == x2.dtype());
+   TORCH_CHECK(cos.dtype() == sin.dtype());
+   TORCH_CHECK(out1.dtype() == out2.dtype());
+   TORCH_CHECK(x1.dtype() == cos.dtype());
+   TORCH_CHECK(x1.dtype() == out1.dtype());
+   TORCH_CHECK(x1.sizes() == x2.sizes());
+   TORCH_CHECK(cos.sizes() == sin.sizes());
+   TORCH_CHECK(out1.sizes() == out2.sizes());
+
+   // Otherwise the kernel will be launched from the cuda:0 device.
+   at::cuda::CUDAGuard device_guard{x1.device()};
+
+   apply_rotary_cuda(x1, x2, cos, sin, out1, out2, conj);
+ }
+
+ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
+   ops.def("apply_rotary(Tensor x1, Tensor x2, Tensor cos, Tensor sin, "
+           "Tensor! out1, Tensor! out2, bool conj) -> ()");
+   ops.impl("apply_rotary", torch::kCUDA, &apply_rotary);
+ }
+
+ REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
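The `ops.def`/`ops.impl` pair registers the operator under the extension's namespace (`TORCH_EXTENSION_NAME`, fixed at build time), which is presumably what the `_ops` helper imported by `__init__.py` resolves to. For illustration only, with `_rotary_abc123` standing in for whatever namespace and library name the build actually produces, the op is also reachable directly through `torch.ops`:

```python
import torch

# Hypothetical names: the real shared-library path and namespace are set by the build.
torch.ops.load_library("torch-ext/rotary/_rotary_abc123.so")

x1 = torch.randn(4, 32, device="cuda", dtype=torch.float16)
x2 = torch.randn(4, 32, device="cuda", dtype=torch.float16)
cos = torch.randn(4, 32, device="cuda", dtype=torch.float16)  # placeholder angles
sin = torch.randn(4, 32, device="cuda", dtype=torch.float16)
out1, out2 = torch.empty_like(x1), torch.empty_like(x2)

# Matches the schema declared above:
# apply_rotary(Tensor x1, Tensor x2, Tensor cos, Tensor sin, Tensor! out1, Tensor! out2, bool conj) -> ()
torch.ops._rotary_abc123.apply_rotary(x1, x2, cos, sin, out1, out2, False)
```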