Commit bbf5511: Add rotary kernel

Files added:
- README.md +7 -0
- build.toml +17 -0
- rotary/rotary_cuda.cu +45 -0
- torch-ext/registration.h +27 -0
- torch-ext/rotary/__init__.py +19 -0
- torch-ext/torch_binding.cpp +42 -0
README.md
ADDED
@@ -0,0 +1,7 @@
---
license: bsd-3-clause
---

## rotary

rotary embedding kernel from [Flash Attention](https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary).
build.toml
ADDED
@@ -0,0 +1,17 @@
[general]
version = "0.0.1"

[torch]
name = "rotary"
src = [
  "torch-ext/registration.h",
  "torch-ext/torch_binding.cpp",
]
pyroot = "torch-ext"

[kernel.activation]
capabilities = [ "7.0", "7.2", "7.5", "8.0", "8.6", "8.7", "8.9", "9.0" ]
src = [
  "rotary/rotary_cuda.cu",
]
depends = [ "torch" ]
rotary/rotary_cuda.cu
ADDED
@@ -0,0 +1,45 @@
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#include <torch/all.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>

void apply_rotary_cuda(torch::Tensor const &x1, torch::Tensor const &x2,
                       torch::Tensor const &cos, torch::Tensor const &sin,
                       torch::Tensor &out1, torch::Tensor &out2,
                       bool const conj) {
    auto iter = at::TensorIteratorConfig()
        .add_output(out1)
        .add_output(out2)
        .add_input(x1)
        .add_input(x2)
        .add_input(cos)
        .add_input(sin)
        .check_all_same_dtype(false)
        .promote_inputs_to_common_dtype(false)
        .build();

    if (!conj) {
        AT_DISPATCH_FLOATING_TYPES_AND2(at::kBFloat16, at::kHalf, x1.scalar_type(), "rotary_kernel", [&] {
            at::native::gpu_kernel_multiple_outputs(
                iter, [] GPU_LAMBDA (scalar_t x1, scalar_t x2, scalar_t cos,
                                     scalar_t sin) -> thrust::tuple<scalar_t, scalar_t> {
                    scalar_t out1 = float(x1) * float(cos) - float(x2) * float(sin);
                    scalar_t out2 = float(x1) * float(sin) + float(x2) * float(cos);
                    return {out1, out2};
                });
        });
    } else {
        AT_DISPATCH_FLOATING_TYPES_AND2(at::kBFloat16, at::kHalf, x1.scalar_type(), "rotary_kernel", [&] {
            at::native::gpu_kernel_multiple_outputs(
                iter, [] GPU_LAMBDA (scalar_t x1, scalar_t x2, scalar_t cos,
                                     scalar_t sin) -> thrust::tuple<scalar_t, scalar_t> {
                    scalar_t out1 = float(x1) * float(cos) + float(x2) * float(sin);
                    scalar_t out2 = -float(x1) * float(sin) + float(x2) * float(cos);
                    return {out1, out2};
                });
        });
    }
}
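Not part of the commit, for orientation only: the kernel applies the standard rotary rotation elementwise over broadcast inputs, and the same math can be written as a plain PyTorch reference. A minimal sketch, assuming x1/x2 and cos/sin broadcast against each other; it can serve as a correctness check against the CUDA path:

import torch

def apply_rotary_ref(x1, x2, cos, sin, conj=False):
    # Same elementwise math as rotary_cuda.cu, computed in float32 and cast
    # back to the input dtype (the CUDA kernel also accumulates in float).
    x1f, x2f, cosf, sinf = x1.float(), x2.float(), cos.float(), sin.float()
    if not conj:
        out1 = x1f * cosf - x2f * sinf
        out2 = x1f * sinf + x2f * cosf
    else:
        # conj=True applies the inverse rotation (rotate by -theta).
        out1 = x1f * cosf + x2f * sinf
        out2 = -x1f * sinf + x2f * cosf
    return out1.to(x1.dtype), out2.to(x1.dtype)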
torch-ext/registration.h
ADDED
@@ -0,0 +1,27 @@
#pragma once

#include <Python.h>

#define _CONCAT(A, B) A##B
#define CONCAT(A, B) _CONCAT(A, B)

#define _STRINGIFY(A) #A
#define STRINGIFY(A) _STRINGIFY(A)

// A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME
// could be a macro instead of a literal token.
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)

// A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME
// could be a macro instead of a literal token.
#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \
  TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)

// REGISTER_EXTENSION allows the shared library to be loaded and initialized
// via python's import statement.
#define REGISTER_EXTENSION(NAME)                                               \
  PyMODINIT_FUNC CONCAT(PyInit_, NAME)() {                                     \
    static struct PyModuleDef module = {PyModuleDef_HEAD_INIT,                 \
                                        STRINGIFY(NAME), nullptr, 0, nullptr}; \
    return PyModule_Create(&module);                                           \
  }
torch-ext/rotary/__init__.py
ADDED
@@ -0,0 +1,19 @@
from typing import Tuple
import torch

from ._ops import ops


def apply_rotary(
    x1: torch.Tensor,
    x2: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    out1: torch.Tensor,
    out2: torch.Tensor,
    conj: bool,
):
    ops.apply_rotary(x1, x2, cos, sin, out1, out2, conj)


__all__ = ["apply_rotary"]
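Not part of the commit: a rough usage sketch of the wrapper above. The package name `rotary` and the tensor shapes are illustrative assumptions; the binding itself only enforces CUDA tensors, matching dtypes, and preallocated output tensors.

import torch
from rotary import apply_rotary

# Hypothetical shapes: x1/x2 hold the two halves of the rotary dimension,
# cos/sin broadcast over batch and heads.
x1 = torch.randn(2, 128, 8, 32, device="cuda", dtype=torch.float16)
x2 = torch.randn_like(x1)
cos = torch.randn(128, 1, 32, device="cuda", dtype=torch.float16)
sin = torch.randn_like(cos)

# Outputs are preallocated; the op writes the rotated halves into them.
out1 = torch.empty_like(x1)
out2 = torch.empty_like(x2)
apply_rotary(x1, x2, cos, sin, out1, out2, conj=False)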
torch-ext/torch_binding.cpp
ADDED
@@ -0,0 +1,42 @@
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>

#include "registration.h"

#define CHECK_DEVICE(x) TORCH_CHECK(x.device().type() == torch::kCUDA, #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")

void apply_rotary_cuda(torch::Tensor const &x1, torch::Tensor const &x2,
                       torch::Tensor const &cos, torch::Tensor const &sin,
                       torch::Tensor &out1, torch::Tensor &out2,
                       bool const conj);

void apply_rotary(torch::Tensor const &x1, torch::Tensor const &x2,
                  torch::Tensor const &cos, torch::Tensor const &sin,
                  torch::Tensor &out1, torch::Tensor &out2,
                  bool const conj) {
    CHECK_DEVICE(x1); CHECK_DEVICE(x2);
    CHECK_DEVICE(cos); CHECK_DEVICE(sin);
    CHECK_DEVICE(out1); CHECK_DEVICE(out2);
    TORCH_CHECK(x1.dtype() == x2.dtype());
    TORCH_CHECK(cos.dtype() == sin.dtype());
    TORCH_CHECK(out1.dtype() == out2.dtype());
    TORCH_CHECK(x1.dtype() == cos.dtype());
    TORCH_CHECK(x1.dtype() == out1.dtype());
    TORCH_CHECK(x1.sizes() == x2.sizes());
    TORCH_CHECK(cos.sizes() == sin.sizes());
    TORCH_CHECK(out1.sizes() == out2.sizes());

    // Otherwise the kernel will be launched from cuda:0 device
    at::cuda::CUDAGuard device_guard{x1.device()};

    apply_rotary_cuda(x1, x2, cos, sin, out1, out2, conj);
}

TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  ops.def("apply_rotary(Tensor x1, Tensor x2, Tensor cos, Tensor sin,"
          "Tensor! out1, Tensor! out2, bool conj) -> ()");
  ops.impl("apply_rotary", torch::kCUDA, &apply_rotary);
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)