Fix a couple of bugs and add tests from vLLM
Browse files- ext-torch/__init__.py +35 -29
- ext-torch/torch_binding.cpp +4 -0
- tests/__init__.py +0 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/allclose_default.py +14 -0
- tests/kernels/test_activation.py +139 -0
- tests/kernels/utils.py +73 -0
ext-torch/__init__.py
CHANGED
@@ -6,36 +6,42 @@ except ImportError as e:
|
|
6 |
# Fallback for local development.
|
7 |
try:
|
8 |
import _activation
|
|
|
9 |
ops = torch.ops._activition
|
10 |
except ImportError:
|
11 |
raise e
|
12 |
-
|
13 |
-
|
14 |
-
def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
|
15 |
-
ops.silu_and_mul(out, x)
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
|
|
|
|
|
|
|
|
41 |
ops.gelu_quick(out, x)
|
|
|
|
6 |
# Fallback for local development.
|
7 |
try:
|
8 |
import _activation
|
9 |
+
|
10 |
ops = torch.ops._activition
|
11 |
except ImportError:
|
12 |
raise e
|
13 |
+
|
14 |
+
|
15 |
+
def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
|
16 |
+
ops.silu_and_mul(out, x)
|
17 |
+
return out
|
18 |
+
|
19 |
+
|
20 |
+
def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
|
21 |
+
ops.gelu_and_mul(out, x)
|
22 |
+
return out
|
23 |
+
|
24 |
+
|
25 |
+
def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
|
26 |
+
ops.gelu_tanh_and_mul(out, x)
|
27 |
+
return out
|
28 |
+
|
29 |
+
|
30 |
+
def fatrelu_and_mul(out: torch.Tensor, x: torch.Tensor, threshold: float = 0.0) -> None:
|
31 |
+
ops.fatrelu_and_mul(out, x, threshold)
|
32 |
+
return out
|
33 |
+
|
34 |
+
|
35 |
+
def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
|
36 |
+
ops.gelu_fast(out, x)
|
37 |
+
return out
|
38 |
+
|
39 |
+
|
40 |
+
def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
|
41 |
+
ops.gelu_new(out, x)
|
42 |
+
return out
|
43 |
+
|
44 |
+
|
45 |
+
def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
|
46 |
ops.gelu_quick(out, x)
|
47 |
+
return out
|
ext-torch/torch_binding.cpp
CHANGED
@@ -28,6 +28,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|
28 |
// Approximate GELU implementation.
|
29 |
ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
|
30 |
ops.impl("gelu_fast", torch::kCUDA, &gelu_fast);
|
|
|
|
|
|
|
|
|
31 |
}
|
32 |
|
33 |
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
|
|
28 |
// Approximate GELU implementation.
|
29 |
ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
|
30 |
ops.impl("gelu_fast", torch::kCUDA, &gelu_fast);
|
31 |
+
|
32 |
+
// Quick GELU implementation.
|
33 |
+
ops.def("gelu_quick(Tensor! out, Tensor input) -> ()");
|
34 |
+
ops.impl("gelu_quick", torch::kCUDA, &gelu_quick);
|
35 |
}
|
36 |
|
37 |
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
tests/__init__.py
ADDED
File without changes
|
tests/kernels/__init__.py
ADDED
File without changes
|
tests/kernels/allclose_default.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
# Reference default values of atol and rtol are from
|
4 |
+
# https://github.com/pytorch/pytorch/blob/6d96beb6bec24d73ee3f080bac54d2104068f675/test/test_transformers.py#L67
|
5 |
+
default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float: 1e-5}
|
6 |
+
default_rtol = {torch.float16: 1e-3, torch.bfloat16: 1.6e-2, torch.float: 1.3e-6}
|
7 |
+
|
8 |
+
|
9 |
+
def get_default_atol(output) -> float:
|
10 |
+
return default_atol[output.dtype]
|
11 |
+
|
12 |
+
|
13 |
+
def get_default_rtol(output) -> float:
|
14 |
+
return default_rtol[output.dtype]
|
tests/kernels/test_activation.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import random
|
3 |
+
from typing import Type
|
4 |
+
|
5 |
+
import activation
|
6 |
+
import pytest
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
from .utils import opcheck
|
11 |
+
from .allclose_default import get_default_atol, get_default_rtol
|
12 |
+
|
13 |
+
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
14 |
+
NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing
|
15 |
+
D = [512, 13824] # Arbitrary values for testing
|
16 |
+
SEEDS = [0]
|
17 |
+
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
|
18 |
+
|
19 |
+
|
20 |
+
def gelu_fast(x: torch.Tensor) -> torch.Tensor:
|
21 |
+
return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
|
22 |
+
|
23 |
+
|
24 |
+
def gelu_new(x: torch.Tensor) -> torch.Tensor:
|
25 |
+
c = math.sqrt(2.0 / math.pi)
|
26 |
+
return 0.5 * x * (1.0 + torch.tanh(c * (x + 0.044715 * torch.pow(x, 3.0))))
|
27 |
+
|
28 |
+
|
29 |
+
def gelu_quick(x: torch.Tensor) -> torch.Tensor:
|
30 |
+
return x * torch.sigmoid(1.702 * x)
|
31 |
+
|
32 |
+
|
33 |
+
def fatrelu_and_mul(x: torch.Tensor, threshold: float) -> torch.Tensor:
|
34 |
+
d = x.shape[-1] // 2
|
35 |
+
x1 = x[..., :d]
|
36 |
+
x2 = x[..., d:]
|
37 |
+
x1 = F.threshold(x1, threshold, 0.0)
|
38 |
+
return x1 * x2
|
39 |
+
|
40 |
+
|
41 |
+
def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
|
42 |
+
d = x.shape[-1] // 2
|
43 |
+
return F.silu(x[..., :d]) * x[..., d:]
|
44 |
+
|
45 |
+
|
46 |
+
def gelu_and_mul(x: torch.Tensor, approximate: str) -> torch.Tensor:
|
47 |
+
d = x.shape[-1] // 2
|
48 |
+
return F.gelu(x[..., :d], approximate=approximate) * x[..., d:]
|
49 |
+
|
50 |
+
|
51 |
+
@pytest.mark.parametrize("activation_name", ["silu", "gelu", "gelu_tanh", "fatrelu"])
|
52 |
+
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
53 |
+
@pytest.mark.parametrize("d", D)
|
54 |
+
@pytest.mark.parametrize("dtype", DTYPES)
|
55 |
+
@pytest.mark.parametrize("seed", SEEDS)
|
56 |
+
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
57 |
+
@torch.inference_mode()
|
58 |
+
def test_act_and_mul(
|
59 |
+
activation_name: str,
|
60 |
+
num_tokens: int,
|
61 |
+
d: int,
|
62 |
+
dtype: torch.dtype,
|
63 |
+
seed: int,
|
64 |
+
device: str,
|
65 |
+
) -> None:
|
66 |
+
random.seed(seed)
|
67 |
+
torch.manual_seed(seed)
|
68 |
+
torch.set_default_device(device)
|
69 |
+
x = torch.randn(num_tokens, 2 * d, dtype=dtype)
|
70 |
+
if activation_name == "silu":
|
71 |
+
torch_fn = silu_and_mul
|
72 |
+
fn = activation.silu_and_mul
|
73 |
+
op = activation.ops.silu_and_mul
|
74 |
+
elif activation_name == "gelu":
|
75 |
+
torch_fn = lambda x: gelu_and_mul(x, "none")
|
76 |
+
fn = activation.gelu_and_mul
|
77 |
+
op = activation.ops.gelu_and_mul
|
78 |
+
elif activation_name == "gelu_tanh":
|
79 |
+
torch_fn = lambda x: gelu_and_mul(x, "tanh")
|
80 |
+
fn = activation.gelu_tanh_and_mul
|
81 |
+
op = activation.ops.gelu_tanh_and_mul
|
82 |
+
elif activation_name == "fatrelu":
|
83 |
+
threshold = random.uniform(0, 1)
|
84 |
+
torch_fn = lambda x: fatrelu_and_mul(x, threshold)
|
85 |
+
fn = lambda out, x: activation.fatrelu_and_mul(out, x, threshold)
|
86 |
+
op = activation.ops.fatrelu_and_mul
|
87 |
+
|
88 |
+
out_shape = x.shape[:-1] + (x.shape[-1] // 2,)
|
89 |
+
out = torch.empty(out_shape, dtype=x.dtype, device=x.device)
|
90 |
+
out = fn(out, x)
|
91 |
+
ref_out = torch_fn(x)
|
92 |
+
|
93 |
+
# The SiLU, GELU and FatReLU implementations are equivalent to the native
|
94 |
+
# PyTorch implementations, so we can do exact comparison.
|
95 |
+
torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)
|
96 |
+
|
97 |
+
d = x.shape[-1] // 2
|
98 |
+
output_shape = x.shape[:-1] + (d,)
|
99 |
+
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
100 |
+
if activation_name == "fatrelu":
|
101 |
+
opcheck(op, (out, x, threshold))
|
102 |
+
else:
|
103 |
+
opcheck(op, (out, x))
|
104 |
+
|
105 |
+
|
106 |
+
@pytest.mark.parametrize(
|
107 |
+
"activation_fns",
|
108 |
+
[
|
109 |
+
(gelu_fast, activation.gelu_fast, activation.ops.gelu_fast),
|
110 |
+
(gelu_new, activation.gelu_new, activation.ops.gelu_new),
|
111 |
+
(gelu_quick, activation.gelu_quick, activation.ops.gelu_quick),
|
112 |
+
],
|
113 |
+
)
|
114 |
+
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
115 |
+
@pytest.mark.parametrize("d", D)
|
116 |
+
@pytest.mark.parametrize("dtype", DTYPES)
|
117 |
+
@pytest.mark.parametrize("seed", SEEDS)
|
118 |
+
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
119 |
+
@torch.inference_mode()
|
120 |
+
def test_activation(
|
121 |
+
activation_fns,
|
122 |
+
num_tokens: int,
|
123 |
+
d: int,
|
124 |
+
dtype: torch.dtype,
|
125 |
+
seed: int,
|
126 |
+
device: str,
|
127 |
+
) -> None:
|
128 |
+
torch.manual_seed(seed)
|
129 |
+
torch.set_default_device(device)
|
130 |
+
x = torch.randn(num_tokens, d, dtype=dtype)
|
131 |
+
torch_fn, fn, op = activation_fns
|
132 |
+
out = fn(torch.empty_like(x), x)
|
133 |
+
ref_out = torch_fn(x)
|
134 |
+
torch.testing.assert_close(
|
135 |
+
out, ref_out, atol=get_default_atol(out), rtol=get_default_rtol(out)
|
136 |
+
)
|
137 |
+
|
138 |
+
out = torch.empty_like(x)
|
139 |
+
opcheck(op, (out, x))
|
tests/kernels/utils.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Kernel test utils"""
|
2 |
+
|
3 |
+
import itertools
|
4 |
+
import random
|
5 |
+
import unittest
|
6 |
+
from numbers import Number
|
7 |
+
from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
|
8 |
+
|
9 |
+
import pytest
|
10 |
+
import torch
|
11 |
+
from torch._prims_common import TensorLikeType
|
12 |
+
|
13 |
+
# For now, disable "test_aot_dispatch_dynamic" since there are some
|
14 |
+
# bugs related to this test in PyTorch 2.4.
|
15 |
+
DEFAULT_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
|
16 |
+
"test_schema",
|
17 |
+
"test_autograd_registration",
|
18 |
+
"test_faketensor",
|
19 |
+
)
|
20 |
+
|
21 |
+
ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
|
22 |
+
"test_schema",
|
23 |
+
"test_autograd_registration",
|
24 |
+
"test_faketensor",
|
25 |
+
"test_aot_dispatch_dynamic",
|
26 |
+
)
|
27 |
+
|
28 |
+
|
29 |
+
# Copied/modified from torch._refs.__init__.py
|
30 |
+
def fp8_allclose(
|
31 |
+
a: TensorLikeType,
|
32 |
+
b: TensorLikeType,
|
33 |
+
rtol: float = 1e-05,
|
34 |
+
atol: float = 1e-08,
|
35 |
+
equal_nan: bool = False,
|
36 |
+
) -> bool:
|
37 |
+
"""
|
38 |
+
Reference implementation of torch.allclose
|
39 |
+
"""
|
40 |
+
torch._refs._check_close_args(name="torch.allclose", a=a, b=b, rtol=rtol, atol=atol)
|
41 |
+
|
42 |
+
return bool(
|
43 |
+
torch.all(
|
44 |
+
torch.isclose(
|
45 |
+
a.double(), b.double(), rtol=rtol, atol=atol, equal_nan=equal_nan
|
46 |
+
)
|
47 |
+
).item()
|
48 |
+
)
|
49 |
+
|
50 |
+
|
51 |
+
# A special version of op check that has a restricted default set of test_utils
|
52 |
+
# and a patched version of allclose that supports fp8 types.
|
53 |
+
def opcheck(
|
54 |
+
op: Union[
|
55 |
+
torch._ops.OpOverload,
|
56 |
+
torch._ops.OpOverloadPacket,
|
57 |
+
torch._library.custom_ops.CustomOpDef,
|
58 |
+
],
|
59 |
+
args: Tuple[Any, ...],
|
60 |
+
kwargs: Optional[Dict[str, Any]] = None,
|
61 |
+
*,
|
62 |
+
test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
|
63 |
+
raise_exception: bool = True,
|
64 |
+
cond: bool = True
|
65 |
+
) -> Dict[str, str]:
|
66 |
+
with unittest.mock.patch("torch.allclose", new=fp8_allclose):
|
67 |
+
return (
|
68 |
+
torch.library.opcheck(
|
69 |
+
op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception
|
70 |
+
)
|
71 |
+
if cond
|
72 |
+
else {}
|
73 |
+
)
|