|
import math |
|
import random |
|
from typing import Type |
|
|
|
import activation |
|
import pytest |
|
import torch |
|
import torch.nn.functional as F |
|
|
|
from .utils import opcheck |
|
from .allclose_default import get_default_atol, get_default_rtol |
|
|
|
# Dtypes swept by every test in this module.
DTYPES = [torch.half, torch.bfloat16, torch.float]
# Row counts of the random input tensors.
NUM_TOKENS = [7, 83, 2048]
# Feature width: test_act_and_mul builds inputs with 2 * d columns,
# test_activation builds inputs with d columns.
D = [512, 13824]
SEEDS = [0]
# Use at most two GPUs. On a host with no CUDA device this list is empty, so
# the device-parametrized tests are skipped instead of failing on device ids
# that do not exist (the previous expression produced two ids in that case).
CUDA_DEVICES = [f"cuda:{i}" for i in range(min(torch.cuda.device_count(), 2))]
|
|
|
|
|
def gelu_fast(x: torch.Tensor) -> torch.Tensor:
    """Reference tanh-approximated GELU.

    Computes ``0.5 * x * (1 + tanh(0.7978845608 * x * (1 + 0.044715 * x**2)))``
    elementwise. The literal 0.7978845608 is sqrt(2 / pi) truncated to ten
    decimals.
    """
    # Keep the exact operation order of the formula so results are bit-identical.
    tanh_arg = x * 0.7978845608 * (1.0 + 0.044715 * x * x)
    half_x = 0.5 * x
    return half_x * (1.0 + torch.tanh(tanh_arg))
|
|
|
|
|
def gelu_new(x: torch.Tensor) -> torch.Tensor:
    """Reference tanh-approximated GELU using an exact sqrt(2 / pi) factor.

    Computes ``0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))``
    elementwise.
    """
    sqrt_2_over_pi = math.sqrt(2.0 / math.pi)
    cubic_term = 0.044715 * torch.pow(x, 3.0)
    return 0.5 * x * (1.0 + torch.tanh(sqrt_2_over_pi * (x + cubic_term)))
|
|
|
|
|
def gelu_quick(x: torch.Tensor) -> torch.Tensor:
    """Reference sigmoid-based GELU approximation: ``x * sigmoid(1.702 * x)``."""
    gate = torch.sigmoid(1.702 * x)
    return x * gate
|
|
|
|
|
def fatrelu_and_mul(x: torch.Tensor, threshold: float) -> torch.Tensor:
    """Reference FATReLU-and-multiply.

    Splits the last dimension at ``x.shape[-1] // 2``. The first part goes
    through ``F.threshold`` (entries ``<= threshold`` become 0.0) and is then
    multiplied elementwise by the second part.
    """
    half = x.shape[-1] // 2
    gate, up = x[..., :half], x[..., half:]
    return F.threshold(gate, threshold, 0.0) * up
|
|
|
|
|
def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    """Reference SiLU-and-multiply.

    Splits the last dimension at ``x.shape[-1] // 2`` and returns
    ``silu(first_part) * second_part``.
    """
    split_at = x.shape[-1] // 2
    gate, up = x[..., :split_at], x[..., split_at:]
    return F.silu(gate) * up
|
|
|
|
|
def gelu_and_mul(x: torch.Tensor, approximate: str) -> torch.Tensor:
    """Reference GELU-and-multiply.

    Splits the last dimension at ``x.shape[-1] // 2`` and returns
    ``F.gelu(first_part, approximate=approximate) * second_part``.
    ``approximate`` is forwarded unchanged to ``F.gelu`` (this module passes
    ``"none"`` or ``"tanh"``).
    """
    split_at = x.shape[-1] // 2
    activated = F.gelu(x[..., :split_at], approximate=approximate)
    return activated * x[..., split_at:]
|
|
|
|
|
def _act_and_mul_fns(activation_name: str):
    """Resolve the reference, wrapper, and raw op for a fused act-and-mul kernel.

    Returns ``(torch_fn, fn, op, extra_op_args)``:
      - ``torch_fn(x)`` is the pure-PyTorch reference defined in this module,
      - ``fn(out, x)`` calls the ``activation`` package wrapper,
      - ``op`` is the matching ``activation.ops`` entry passed to ``opcheck``,
      - ``extra_op_args`` are the arguments ``op`` takes after ``(out, x)``.

    For ``"fatrelu"`` the threshold is drawn with ``random.uniform(0, 1)``,
    so seed Python's ``random`` module before calling this.

    Raises:
        ValueError: if ``activation_name`` is not one of the supported names.
    """
    if activation_name == "silu":
        return silu_and_mul, activation.silu_and_mul, activation.ops.silu_and_mul, ()
    if activation_name == "gelu":
        return (
            lambda x: gelu_and_mul(x, "none"),
            activation.gelu_and_mul,
            activation.ops.gelu_and_mul,
            (),
        )
    if activation_name == "gelu_tanh":
        return (
            lambda x: gelu_and_mul(x, "tanh"),
            activation.gelu_tanh_and_mul,
            activation.ops.gelu_tanh_and_mul,
            (),
        )
    if activation_name == "fatrelu":
        threshold = random.uniform(0, 1)
        return (
            lambda x: fatrelu_and_mul(x, threshold),
            lambda out, x: activation.fatrelu_and_mul(out, x, threshold),
            activation.ops.fatrelu_and_mul,
            (threshold,),
        )
    raise ValueError(f"Unsupported activation: {activation_name!r}")


@pytest.mark.parametrize("activation_name", ["silu", "gelu", "gelu_tanh", "fatrelu"])
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_act_and_mul(
    activation_name: str,
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Check a fused act-and-mul kernel against its PyTorch reference.

    The input has shape ``(num_tokens, 2 * d)``. The value returned by the
    wrapper ``fn(out, x)`` must equal the reference exactly
    (``atol=rtol=0``). The raw op is then run through ``opcheck``.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    torch_fn, fn, op, extra_op_args = _act_and_mul_fns(activation_name)

    # Scope the default device to this test. torch.set_default_device would
    # change process-wide state that leaks into every test that runs after.
    with torch.device(device):
        x = torch.randn(num_tokens, 2 * d, dtype=dtype)
        out_shape = x.shape[:-1] + (x.shape[-1] // 2,)

        out = fn(torch.empty(out_shape, dtype=x.dtype, device=x.device), x)
        ref_out = torch_fn(x)
        # Exact comparison: any deviation from the reference is a failure.
        torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)

        # Fresh output buffer for the operator-registration checks.
        opcheck_out = torch.empty(out_shape, dtype=x.dtype, device=x.device)
        opcheck(op, (opcheck_out, x, *extra_op_args))
|
|
|
|
|
@pytest.mark.parametrize(
    "activation_fns",
    [
        (gelu_fast, activation.gelu_fast, activation.ops.gelu_fast),
        (gelu_new, activation.gelu_new, activation.ops.gelu_new),
        (gelu_quick, activation.gelu_quick, activation.ops.gelu_quick),
    ],
)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_activation(
    activation_fns: tuple,
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Check an elementwise activation kernel against its PyTorch reference.

    ``activation_fns`` is ``(torch_fn, fn, op)``: the reference from this
    module, the ``activation`` wrapper called as ``fn(out, x)``, and the raw
    ``activation.ops`` entry passed to ``opcheck``. The input has shape
    ``(num_tokens, d)``. Tolerances come from ``get_default_atol`` and
    ``get_default_rtol`` evaluated on the kernel output (presumably
    dtype-dependent; see ``allclose_default``).
    """
    torch.manual_seed(seed)
    torch_fn, fn, op = activation_fns

    # Scope the default device to this test instead of calling
    # torch.set_default_device, which leaks into later tests.
    with torch.device(device):
        x = torch.randn(num_tokens, d, dtype=dtype)

        out = fn(torch.empty_like(x), x)
        ref_out = torch_fn(x)
        torch.testing.assert_close(
            out, ref_out, atol=get_default_atol(out), rtol=get_default_rtol(out)
        )

        # Fresh output buffer for the operator-registration checks.
        opcheck(op, (torch.empty_like(x), x))
|
|