import math
import random

import pytest
import torch
import torch.nn.functional as F

import activation

from .allclose_default import get_default_atol, get_default_rtol
from .utils import opcheck

DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing
D = [512, 13824] # Arbitrary values for testing
SEEDS = [0]
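# Run on up to two CUDA devices so device placement is exercised on multi-GPU hosts.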
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
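

# Eager PyTorch reference implementations used to check the fused CUDA ops.
# Fast GELU: tanh approximation with 0.7978845608 ~ sqrt(2 / pi) folded in.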
def gelu_fast(x: torch.Tensor) -> torch.Tensor:
return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
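

# "New" GELU: tanh approximation written as sqrt(2 / pi) * (x + 0.044715 * x^3).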
def gelu_new(x: torch.Tensor) -> torch.Tensor:
c = math.sqrt(2.0 / math.pi)
return 0.5 * x * (1.0 + torch.tanh(c * (x + 0.044715 * torch.pow(x, 3.0))))
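

# Quick GELU: sigmoid approximation x * sigmoid(1.702 * x).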
def gelu_quick(x: torch.Tensor) -> torch.Tensor:
return x * torch.sigmoid(1.702 * x)
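

# FATReLU-and-mul: hard-threshold the first half of the last dim, then gate
# (multiply) it with the second half.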
def fatrelu_and_mul(x: torch.Tensor, threshold: float) -> torch.Tensor:
d = x.shape[-1] // 2
x1 = x[..., :d]
x2 = x[..., d:]
x1 = F.threshold(x1, threshold, 0.0)
return x1 * x2
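

# SiLU on the first half of the last dim, multiplied by the second half.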
def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
return F.silu(x[..., :d]) * x[..., d:]
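

# GELU (exact or tanh-approximated) on the first half, multiplied by the second half.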
def gelu_and_mul(x: torch.Tensor, approximate: str) -> torch.Tensor:
d = x.shape[-1] // 2
return F.gelu(x[..., :d], approximate=approximate) * x[..., d:]
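

# Compare each fused *_and_mul kernel against its eager reference and exercise
# the registered custom op via opcheck.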
@pytest.mark.parametrize("activation_name", ["silu", "gelu", "gelu_tanh", "fatrelu"])
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_act_and_mul(
activation_name: str,
num_tokens: int,
d: int,
dtype: torch.dtype,
seed: int,
device: str,
) -> None:
random.seed(seed)
torch.manual_seed(seed)
torch.set_default_device(device)
x = torch.randn(num_tokens, 2 * d, dtype=dtype)
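    # Select the eager reference, the Python wrapper, and the registered custom op.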
if activation_name == "silu":
torch_fn = silu_and_mul
fn = activation.silu_and_mul
op = activation.ops.silu_and_mul
elif activation_name == "gelu":
torch_fn = lambda x: gelu_and_mul(x, "none")
fn = activation.gelu_and_mul
op = activation.ops.gelu_and_mul
elif activation_name == "gelu_tanh":
torch_fn = lambda x: gelu_and_mul(x, "tanh")
fn = activation.gelu_tanh_and_mul
op = activation.ops.gelu_tanh_and_mul
elif activation_name == "fatrelu":
threshold = random.uniform(0, 1)
torch_fn = lambda x: fatrelu_and_mul(x, threshold)
fn = lambda out, x: activation.fatrelu_and_mul(out, x, threshold)
        op = activation.ops.fatrelu_and_mul
    else:
        raise NotImplementedError(f"Unknown activation: {activation_name}")
out_shape = x.shape[:-1] + (x.shape[-1] // 2,)
out = torch.empty(out_shape, dtype=x.dtype, device=x.device)
out = fn(out, x)
ref_out = torch_fn(x)
# The SiLU, GELU and FatReLU implementations are equivalent to the native
# PyTorch implementations, so we can do exact comparison.
torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)
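    # Exercise the registered op directly via opcheck on a fresh output buffer.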
    out = torch.empty(out_shape, dtype=x.dtype, device=x.device)
if activation_name == "fatrelu":
opcheck(op, (out, x, threshold))
else:
opcheck(op, (out, x))
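

# Element-wise GELU variants (fast, new, quick): compare each fused kernel
# against its eager reference within the default tolerances for the dtype.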
@pytest.mark.parametrize(
"activation_fns",
[
(gelu_fast, activation.gelu_fast, activation.ops.gelu_fast),
(gelu_new, activation.gelu_new, activation.ops.gelu_new),
(gelu_quick, activation.gelu_quick, activation.ops.gelu_quick),
],
)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_activation(
activation_fns,
num_tokens: int,
d: int,
dtype: torch.dtype,
seed: int,
device: str,
) -> None:
torch.manual_seed(seed)
torch.set_default_device(device)
x = torch.randn(num_tokens, d, dtype=dtype)
torch_fn, fn, op = activation_fns
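    # Run the fused kernel into a fresh buffer and compare against the reference.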
out = fn(torch.empty_like(x), x)
ref_out = torch_fn(x)
torch.testing.assert_close(
out, ref_out, atol=get_default_atol(out), rtol=get_default_rtol(out)
)
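    # Also exercise the registered op directly via opcheck.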
out = torch.empty_like(x)
opcheck(op, (out, x))