# activation/tests/kernels/test_activation.py
# Author: danieldk (HF staff)
# Commit 5600c5f: Fix a couple of bugs and add tests from vLLM
import math
import random
from typing import Type
import activation
import pytest
import torch
import torch.nn.functional as F
from .utils import opcheck
from .allclose_default import get_default_atol, get_default_rtol
# Parametrization axes shared by the tests below; the sizes are arbitrary.
DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [7, 83, 2048]
D = [512, 13824]
SEEDS = [0]
# A single device when exactly one GPU is visible; otherwise cuda:0 and
# cuda:1 (note: this branch is also taken when no GPU is visible).
CUDA_DEVICES = [
    f"cuda:{index}"
    for index in range(2 if torch.cuda.device_count() != 1 else 1)
]
def gelu_fast(x: torch.Tensor) -> torch.Tensor:
    """Reference "fast" GELU: the tanh approximation with a literal sqrt(2/pi).

    Computes ``0.5 * x * (1 + tanh(0.7978845608 * x * (1 + 0.044715 * x**2)))``.
    """
    # 0.7978845608 ~= sqrt(2 / pi); x * (1 + 0.044715 * x^2) == x + 0.044715 * x^3.
    tanh_arg = x * 0.7978845608 * (1.0 + 0.044715 * x * x)
    return 0.5 * x * (1.0 + torch.tanh(tanh_arg))
def gelu_new(x: torch.Tensor) -> torch.Tensor:
    """Reference tanh-approximate GELU with sqrt(2/pi) computed at call time.

    Computes ``0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))``.
    """
    sqrt_2_over_pi = math.sqrt(2.0 / math.pi)
    cubic_term = x + 0.044715 * torch.pow(x, 3.0)
    return 0.5 * x * (1.0 + torch.tanh(sqrt_2_over_pi * cubic_term))
def gelu_quick(x: torch.Tensor) -> torch.Tensor:
    """Reference "quick" GELU: ``x * sigmoid(1.702 * x)``."""
    gate = torch.sigmoid(1.702 * x)
    return x * gate
def fatrelu_and_mul(x: torch.Tensor, threshold: float) -> torch.Tensor:
    """Reference FATReLU-and-mul over the last dimension.

    With ``half = x.shape[-1] // 2``, the first ``half`` columns are passed
    through ``F.threshold`` (elements not greater than ``threshold`` become
    0.0). The result is multiplied elementwise by the remaining columns.
    """
    half = x.shape[-1] // 2
    gate, up = x[..., :half], x[..., half:]
    return F.threshold(gate, threshold, 0.0) * up
def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    """Reference ``silu(x[..., :half]) * x[..., half:]`` with ``half = x.shape[-1] // 2``."""
    half = x.size(-1) // 2
    gate, up = x[..., :half], x[..., half:]
    return F.silu(gate) * up
def gelu_and_mul(x: torch.Tensor, approximate: str) -> torch.Tensor:
    """Reference ``gelu(x[..., :half]) * x[..., half:]`` with ``half = x.shape[-1] // 2``.

    ``approximate`` is forwarded unchanged to ``F.gelu`` (``"none"`` for the
    exact GELU, ``"tanh"`` for the tanh approximation).
    """
    half = x.shape[-1] // 2
    activated = F.gelu(x[..., :half], approximate=approximate)
    return activated * x[..., half:]
@pytest.mark.parametrize("activation_name", ["silu", "gelu", "gelu_tanh", "fatrelu"])
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_act_and_mul(
    activation_name: str,
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Check the fused ``act(x[..., :d]) * x[..., d:]`` kernels.

    Each kernel fills a preallocated ``(num_tokens, d)`` output from a
    ``(num_tokens, 2 * d)`` input. The result must match the pure-PyTorch
    reference exactly, and the registered op is then run through ``opcheck``.

    Raises:
        ValueError: if ``activation_name`` is not one of the supported names.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    torch.set_default_device(device)
    x = torch.randn(num_tokens, 2 * d, dtype=dtype)

    # Extra positional arguments the registered op takes after (out, x).
    extra_args: tuple = ()
    if activation_name == "silu":
        torch_fn = silu_and_mul
        fn = activation.silu_and_mul
        op = activation.ops.silu_and_mul
    elif activation_name == "gelu":
        torch_fn = lambda x: gelu_and_mul(x, "none")
        fn = activation.gelu_and_mul
        op = activation.ops.gelu_and_mul
    elif activation_name == "gelu_tanh":
        torch_fn = lambda x: gelu_and_mul(x, "tanh")
        fn = activation.gelu_tanh_and_mul
        op = activation.ops.gelu_tanh_and_mul
    elif activation_name == "fatrelu":
        # Drawn from the seeded `random` module, so it is reproducible per seed.
        threshold = random.uniform(0, 1)
        torch_fn = lambda x: fatrelu_and_mul(x, threshold)
        fn = lambda out, x: activation.fatrelu_and_mul(out, x, threshold)
        op = activation.ops.fatrelu_and_mul
        extra_args = (threshold,)
    else:
        # Fail loudly instead of hitting a NameError on `torch_fn` below.
        raise ValueError(f"Unknown activation: {activation_name}")

    # The output keeps every leading dimension and halves the last one.
    out_shape = x.shape[:-1] + (x.shape[-1] // 2,)
    out = torch.empty(out_shape, dtype=x.dtype, device=x.device)
    out = fn(out, x)
    ref_out = torch_fn(x)
    # The SiLU, GELU (exact and tanh) and FatReLU kernels are expected to be
    # equivalent to the native PyTorch implementations, so compare exactly.
    torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)

    # opcheck gets its own freshly allocated output buffer.
    out = torch.empty(out_shape, dtype=x.dtype, device=x.device)
    opcheck(op, (out, x, *extra_args))
@pytest.mark.parametrize(
    "activation_fns",
    [
        (gelu_fast, activation.gelu_fast, activation.ops.gelu_fast),
        (gelu_new, activation.gelu_new, activation.ops.gelu_new),
        (gelu_quick, activation.gelu_quick, activation.ops.gelu_quick),
    ],
)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_activation(
    activation_fns,
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Compare each elementwise GELU kernel against its PyTorch reference.

    ``activation_fns`` is a ``(reference, kernel, registered_op)`` triple.
    The kernel writes into a same-shaped output buffer, and the result is
    checked with the default tolerances for its dtype. The registered op is
    then validated with ``opcheck``.
    """
    reference, kernel, registered_op = activation_fns

    torch.manual_seed(seed)
    torch.set_default_device(device)
    x = torch.randn(num_tokens, d, dtype=dtype)

    result = kernel(torch.empty_like(x), x)
    expected = reference(x)
    atol = get_default_atol(result)
    rtol = get_default_rtol(result)
    torch.testing.assert_close(result, expected, atol=atol, rtol=rtol)

    opcheck(registered_op, (torch.empty_like(x), x))