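"""Tests for the custom activation ops exposed by the `activation` package.

Each op is compared against a pure-PyTorch reference implementation, and the
registered op is additionally checked with `opcheck`.
"""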
import math
import random

import activation
import pytest
import torch
import torch.nn.functional as F

from .utils import opcheck
from .allclose_default import get_default_atol, get_default_rtol

DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
D = [512, 13824]  # Arbitrary values for testing
SEEDS = [0]
# Use up to two CUDA devices when more than one is available.
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]


def gelu_fast(x: torch.Tensor) -> torch.Tensor:
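    """Tanh-based "fast" GELU reference; 0.7978845608 ~= sqrt(2 / pi)."""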
    return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))


def gelu_new(x: torch.Tensor) -> torch.Tensor:
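    """Tanh GELU reference written with an explicit sqrt(2 / pi) factor."""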
    c = math.sqrt(2.0 / math.pi)
    return 0.5 * x * (1.0 + torch.tanh(c * (x + 0.044715 * torch.pow(x, 3.0))))


def gelu_quick(x: torch.Tensor) -> torch.Tensor:
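    """Sigmoid ("quick") GELU reference: x * sigmoid(1.702 * x)."""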
    return x * torch.sigmoid(1.702 * x)


def fatrelu_and_mul(x: torch.Tensor, threshold: float) -> torch.Tensor:
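    """Apply FATReLU (hard threshold) to the first half of the last dim and
    multiply elementwise by the second half."""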
    d = x.shape[-1] // 2
    x1 = x[..., :d]
    x2 = x[..., d:]
    x1 = F.threshold(x1, threshold, 0.0)
    return x1 * x2


def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
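    """Apply SiLU to the first half of the last dim and multiply by the second half."""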
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]


def gelu_and_mul(x: torch.Tensor, approximate: str) -> torch.Tensor:
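    """Apply GELU to the first half of the last dim and multiply by the second half."""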
    d = x.shape[-1] // 2
    return F.gelu(x[..., :d], approximate=approximate) * x[..., d:]


@pytest.mark.parametrize("activation_name", ["silu", "gelu", "gelu_tanh", "fatrelu"])
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_act_and_mul(
    activation_name: str,
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    random.seed(seed)
    torch.manual_seed(seed)
    torch.set_default_device(device)
    x = torch.randn(num_tokens, 2 * d, dtype=dtype)
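    # Pick the PyTorch reference (`torch_fn`), the callable under test (`fn`),
    # and the registered op used for `opcheck`.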
    if activation_name == "silu":
        torch_fn = silu_and_mul
        fn = activation.silu_and_mul
        op = activation.ops.silu_and_mul
    elif activation_name == "gelu":
        torch_fn = lambda x: gelu_and_mul(x, "none")
        fn = activation.gelu_and_mul
        op = activation.ops.gelu_and_mul
    elif activation_name == "gelu_tanh":
        torch_fn = lambda x: gelu_and_mul(x, "tanh")
        fn = activation.gelu_tanh_and_mul
        op = activation.ops.gelu_tanh_and_mul
    elif activation_name == "fatrelu":
        threshold = random.uniform(0, 1)
        torch_fn = lambda x: fatrelu_and_mul(x, threshold)
        fn = lambda out, x: activation.fatrelu_and_mul(out, x, threshold)
        op = activation.ops.fatrelu_and_mul

    # The fused op writes into a preallocated output buffer whose last dim is
    # half the input width (the two halves of `x` are combined).
    out_shape = x.shape[:-1] + (x.shape[-1] // 2,)
    out = torch.empty(out_shape, dtype=x.dtype, device=x.device)
    out = fn(out, x)
    ref_out = torch_fn(x)

    # The SiLU, GELU and FatReLU implementations are equivalent to the native
    # PyTorch implementations, so we can do exact comparison.
    torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)

    # Re-run through opcheck with a fresh output buffer to exercise the
    # registered op (`activation.ops.*`) directly.
    out = torch.empty(out_shape, dtype=x.dtype, device=x.device)
    if activation_name == "fatrelu":
        opcheck(op, (out, x, threshold))
    else:
        opcheck(op, (out, x))


@pytest.mark.parametrize(
    "activation_fns",
    [
        (gelu_fast, activation.gelu_fast, activation.ops.gelu_fast),
        (gelu_new, activation.gelu_new, activation.ops.gelu_new),
        (gelu_quick, activation.gelu_quick, activation.ops.gelu_quick),
    ],
)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_activation(
    activation_fns,
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    torch.manual_seed(seed)
    torch.set_default_device(device)
    x = torch.randn(num_tokens, d, dtype=dtype)
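    # (PyTorch reference, callable under test, registered op for opcheck).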
    torch_fn, fn, op = activation_fns
    out = fn(torch.empty_like(x), x)
    ref_out = torch_fn(x)
    torch.testing.assert_close(
        out, ref_out, atol=get_default_atol(out), rtol=get_default_rtol(out)
    )

    # Check the registered op itself with a fresh output buffer.
    out = torch.empty_like(x)
    opcheck(op, (out, x))