import math

import torch

import triton
import triton.language as tl


@triton.heuristics(
    {
        "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
        "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
        "EVEN_C": lambda args: args["nchunks"] % args["BLOCK_N"] == 0,
        "EVEN_W": lambda args: args["WINDOW_SIZE"] % args["BLOCK_N"] == 0,
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
@triton.jit
def _fwd_eva_agg_kernel(
    Q,
    K,
    V,
    RFA_K,
    RFA_V,
    WindowMask,
    Out,
    softmax_scale,
    stride_qb, stride_qh, stride_qm,
    stride_kb, stride_kh, stride_kn,
    stride_vb, stride_vh, stride_vn,
    stride_rfa_kb, stride_rfa_kh, stride_rfa_kc,
    stride_rfa_vb, stride_rfa_vh, stride_rfa_vc,
    stride_mb, stride_mm,
    stride_ob, stride_oh, stride_om,
    nheads,
    seqlen_q,
    seqlen_k,
    nchunks,
    headdim,
    CACHE_KEY_SEQLEN_Q,
    CACHE_KEY_SEQLEN_K,
    CACHE_KEY_NCHUNKS,
    CHUNKS_PER_WINDOW: tl.constexpr,
    WINDOW_SIZE: tl.constexpr,
    MASK_TYPE: tl.constexpr,
    EMPTY_RFA_KV: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_M: tl.constexpr,
    EVEN_N: tl.constexpr,
    EVEN_W: tl.constexpr,
    EVEN_C: tl.constexpr,
    EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
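    # Each program computes one BLOCK_M x headdim tile of Out for a single
    # (batch, head) pair, in two passes that share one set of online-softmax
    # statistics: (1) exact attention over the keys/values of the query
    # block's local window, and (2) attention over per-chunk RFA summaries
    # (RFA_K / RFA_V) of all chunks preceding that window. Sharing the
    # statistics makes the result a single softmax over both groups of scores.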
|
    start_m = tl.program_id(0)
    off_bh = tl.program_id(1)
    off_h = off_bh % nheads
    off_b = off_bh // nheads

    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # Index of the local-attention window containing this query block.
    offs_w = (start_m * BLOCK_M) // WINDOW_SIZE
    offs_n = tl.arange(0, BLOCK_N)
    offs_c = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_HEADDIM)

    q_ptrs = (
        Q +
        off_b * stride_qb +
        off_h * stride_qh +
        (offs_m[:, None] * stride_qm + offs_d[None, :])
    )
    k_ptrs = (
        K +
        off_b * stride_kb +
        off_h * stride_kh +
        (offs_n[:, None] * stride_kn + offs_d[None, :])
    )
    v_ptrs = (
        V +
        off_b * stride_vb +
        off_h * stride_vh +
        (offs_n[:, None] * stride_vn + offs_d[None, :])
    )
    if EMPTY_RFA_KV == 0:
        rfa_k_ptrs = (
            RFA_K +
            off_b * stride_rfa_kb +
            off_h * stride_rfa_kh +
            (offs_c[:, None] * stride_rfa_kc + offs_d[None, :])
        )
        rfa_v_ptrs = (
            RFA_V +
            off_b * stride_rfa_vb +
            off_h * stride_rfa_vh +
            (offs_c[:, None] * stride_rfa_vc + offs_d[None, :])
        )

    # Fold log2(e) into the softmax scale so exp2 (faster on GPU) can replace
    # exp in the inner loops.
    qk_scale = softmax_scale
    qk_scale *= 1.4426950408889634
    if MASK_TYPE == 1:
        m_ptrs = (
            WindowMask +
            off_b * stride_mb +
            (offs_m[:, None] * stride_mm + offs_n[None, :])
        )
    # Online-softmax state: running row max (m_i), running normalizer (d_i),
    # and the unnormalized output accumulator (acc_o).
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    d_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
    # Load the query block once; it is reused by both passes. The EVEN_*
    # flags let fully in-bounds tiles skip bounds-check masks entirely.
    if EVEN_M & EVEN_N:
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs)
        else:
            q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
    else:
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
        else:
            q = tl.load(
                q_ptrs,
                mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                other=0.0,
            )

    # Pass 1: exact attention inside the local window. Keys run from the
    # start of this block's window up to the block's last query position.
    start_idx_n = offs_w * WINDOW_SIZE
    end_idx_n = tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
    for start_n in range(start_idx_n, end_idx_n, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
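        # Load one BLOCK_N x headdim tile of K, masking rows/columns that
        # fall outside seqlen_k or headdim when the tile is ragged.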
        if EVEN_N & EVEN_M:
            if EVEN_HEADDIM:
                k = tl.load(k_ptrs + start_n * stride_kn)
            else:
                k = tl.load(
                    k_ptrs + start_n * stride_kn,
                    mask=offs_d[None, :] < headdim,
                    other=0.0,
                )
        else:
            if EVEN_HEADDIM:
                k = tl.load(
                    k_ptrs + start_n * stride_kn,
                    mask=(start_n + offs_n)[:, None] < seqlen_k,
                    other=0.0,
                )
            else:
                k = tl.load(
                    k_ptrs + start_n * stride_kn,
                    mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                    other=0.0,
                )
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, tl.trans(k))

        # Mask out-of-range key columns before the row max is taken.
        if not EVEN_N:
            qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))

        if MASK_TYPE == 1:
            if EVEN_M & EVEN_W:
                mask = tl.load(m_ptrs + start_n - start_idx_n).to(tl.float32)
            else:
                mask = tl.load(
                    m_ptrs + start_n - start_idx_n,
                    mask=(offs_m[:, None] < seqlen_q)
                    & ((start_n - start_idx_n + offs_n)[None, :] < WINDOW_SIZE),
                    other=0.0,
                ).to(tl.float32)
            # The mask is additive and applied after the base-2 rescaling, so
            # finite entries are interpreted in log2 units; the usual 0 / -inf
            # masks are unaffected by this.
            qk = qk * qk_scale + mask
            m_ij = tl.maximum(tl.max(qk, 1), m_i)
            p = tl.exp2(qk - m_ij[:, None])
        else:
            # No explicit mask: enforce causality within the window directly.
            qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
            m_ij = tl.maximum(tl.max(qk, 1) * qk_scale, m_i)
            p = tl.exp2(qk * qk_scale - m_ij[:, None])

        d_ij = tl.sum(p, 1)

        # Rescale the running accumulator when the row max increases.
        prev_scale = tl.exp2(m_i - m_ij)
        acc_o = acc_o * prev_scale[:, None]

        if EVEN_N & EVEN_M:
            if EVEN_HEADDIM:
                v = tl.load(v_ptrs + start_n * stride_vn)
            else:
                v = tl.load(
                    v_ptrs + start_n * stride_vn,
                    mask=offs_d[None, :] < headdim,
                    other=0.0,
                )
        else:
            if EVEN_HEADDIM:
                v = tl.load(
                    v_ptrs + start_n * stride_vn,
                    mask=(start_n + offs_n)[:, None] < seqlen_k,
                    other=0.0,
                )
            else:
                v = tl.load(
                    v_ptrs + start_n * stride_vn,
                    mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                    other=0.0,
                )
        p = p.to(v.dtype)
        acc_o = tl.dot(p, v, acc_o)

        # Update the running softmax statistics.
        d_i = d_i * prev_scale + d_ij
        m_i = m_ij
    # Pass 2: attend to the RFA chunk summaries of everything preceding this
    # query block's window, continuing the same online softmax.
    if EMPTY_RFA_KV == 0:
        # Number of chunk summaries strictly before the current window.
        end_idx_c = tl.minimum(offs_w * CHUNKS_PER_WINDOW, nchunks)
        for start_c in range(0, end_idx_c, BLOCK_N):
            start_c = tl.multiple_of(start_c, BLOCK_N)

            if EVEN_C & EVEN_M:
                if EVEN_HEADDIM:
                    rfa_k = tl.load(rfa_k_ptrs + start_c * stride_rfa_kc)
                else:
                    rfa_k = tl.load(
                        rfa_k_ptrs + start_c * stride_rfa_kc,
                        mask=offs_d[None, :] < headdim,
                        other=0.0,
                    )
            else:
                if EVEN_HEADDIM:
                    rfa_k = tl.load(
                        rfa_k_ptrs + start_c * stride_rfa_kc,
                        mask=(start_c + offs_c)[:, None] < nchunks,
                        other=0.0,
                    )
                else:
                    rfa_k = tl.load(
                        rfa_k_ptrs + start_c * stride_rfa_kc,
                        mask=((start_c + offs_c)[:, None] < nchunks) & (offs_d[None, :] < headdim),
                        other=0.0,
                    )
            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, tl.trans(rfa_k))

            if not EVEN_C:
                qk += tl.where((start_c + offs_c)[None, :] < nchunks, 0, float("-inf"))

            m_ij = tl.maximum(tl.max(qk, 1) * qk_scale, m_i)
            p = tl.exp2(qk * qk_scale - m_ij[:, None])

            d_ij = tl.sum(p, 1)

            prev_scale = tl.exp2(m_i - m_ij)
            acc_o = acc_o * prev_scale[:, None]

            if EVEN_C & EVEN_M:
                if EVEN_HEADDIM:
                    rfa_v = tl.load(rfa_v_ptrs + start_c * stride_rfa_vc)
                else:
                    rfa_v = tl.load(
                        rfa_v_ptrs + start_c * stride_rfa_vc,
                        mask=offs_d[None, :] < headdim,
                        other=0.0,
                    )
            else:
                if EVEN_HEADDIM:
                    rfa_v = tl.load(
                        rfa_v_ptrs + start_c * stride_rfa_vc,
                        mask=(start_c + offs_c)[:, None] < nchunks,
                        other=0.0,
                    )
                else:
                    rfa_v = tl.load(
                        rfa_v_ptrs + start_c * stride_rfa_vc,
                        mask=((start_c + offs_c)[:, None] < nchunks) & (offs_d[None, :] < headdim),
                        other=0.0,
                    )
            p = p.to(rfa_v.dtype)
            acc_o = tl.dot(p, rfa_v, acc_o)

            d_i = d_i * prev_scale + d_ij
            m_i = m_ij
    # Final normalization by the softmax denominator.
    acc_o = acc_o / d_i[:, None]

    # Rematerialize offsets for the epilogue to save registers.
    start_m = tl.program_id(0)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    out_ptrs = (
        Out +
        off_b * stride_ob +
        off_h * stride_oh +
        (offs_m[:, None] * stride_om + offs_d[None, :])
    )
    if EVEN_M:
        if EVEN_HEADDIM:
            tl.store(out_ptrs, acc_o)
        else:
            tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
    else:
        if EVEN_HEADDIM:
            tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
        else:
            tl.store(
                out_ptrs, acc_o,
                mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
            )
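# Note: offs_w is derived as (start_m * BLOCK_M) // WINDOW_SIZE, so the kernel
# implicitly assumes WINDOW_SIZE is a multiple of the block size (a query
# block never straddles two windows); the wrapper below enforces the related
# chunks_per_window >= BLOCK constraint before launching.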
|
|
def triton_eva_agg_fwd(q, k, v, rfa_k, rfa_v, window_mask, softmax_scale, window_size, chunks_per_window):
    if rfa_k is None and rfa_v is None:
        empty_rfa_kv = 1
        # Triton expects the innermost dimension to be contiguous.
        q, k, v = [
            x if x.stride(-1) == 1 else x.contiguous()
            for x in [q, k, v]
        ]
    else:
        assert rfa_k is not None and rfa_v is not None, \
            "rfa_k and rfa_v must either both be None or both be provided."
        empty_rfa_kv = 0
        q, k, v, rfa_k, rfa_v = [
            x if x.stride(-1) == 1 else x.contiguous()
            for x in [q, k, v, rfa_k, rfa_v]
        ]
    batch, nheads, seqlen_q, head_dim = q.shape
    _, _, seqlen_k, _ = k.shape
    if empty_rfa_kv == 0:
        nchunks = rfa_k.shape[-2]
        assert rfa_k.shape == (batch, nheads, nchunks, head_dim)
        assert rfa_v.shape == (batch, nheads, nchunks, head_dim)
        assert q.dtype == k.dtype == v.dtype == rfa_k.dtype == rfa_v.dtype, \
            "All tensors must have the same dtype"
    else:
        nchunks = 0
        assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype"
    assert k.shape == (batch, nheads, seqlen_k, head_dim)
    assert v.shape == (batch, nheads, seqlen_k, head_dim)

    assert head_dim <= 128, "We only test head dimensions up to 128"
    assert q.dtype in [torch.bfloat16, torch.float], "Only support bf16 and fp32 for now"
    assert q.is_cuda and k.is_cuda and v.is_cuda
    softmax_scale = softmax_scale or 1.0 / math.sqrt(head_dim)

    mask_type = 0
    if window_mask is not None:
        mask_type = 1
        assert window_mask.dtype in [q.dtype, torch.float]
        assert window_mask.is_cuda
        assert window_mask.dim() == 4
        assert window_mask.shape == (batch, 1, seqlen_q, window_size)
        if window_mask.stride(-1) != 1:
            window_mask = window_mask.contiguous()
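    # Optional inputs are passed with zero strides; the kernel dereferences
    # WindowMask only when MASK_TYPE == 1 and RFA_K / RFA_V only when
    # EMPTY_RFA_KV == 0.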
    mask_strides = (
        (window_mask.stride(0), window_mask.stride(2))
        if mask_type == 1 else
        (0, 0)
    )
    rfa_k_strides = (
        (rfa_k.stride(0), rfa_k.stride(1), rfa_k.stride(2))
        if empty_rfa_kv == 0 else
        (0, 0, 0)
    )
    rfa_v_strides = (
        (rfa_v.stride(0), rfa_v.stride(1), rfa_v.stride(2))
        if empty_rfa_kv == 0 else
        (0, 0, 0)
    )
    assert chunks_per_window > 0, "chunks_per_window must be greater than 0"

    o = torch.empty_like(q)

    BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16)
    # fp32 tiles need more shared memory than bf16, so use a smaller block.
    if q.dtype == torch.float:
        BLOCK = 64
    else:
        BLOCK = 128
    num_warps = 4 if head_dim <= 64 else 8
    assert chunks_per_window >= BLOCK, "chunks_per_window must be at least BLOCK"
    # One program per (query block, batch * head) pair.
    grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
    _fwd_eva_agg_kernel[grid](
        q,
        k,
        v,
        rfa_k,
        rfa_v,
        window_mask,
        o,
        softmax_scale,
        q.stride(0), q.stride(1), q.stride(2),
        k.stride(0), k.stride(1), k.stride(2),
        v.stride(0), v.stride(1), v.stride(2),
        rfa_k_strides[0], rfa_k_strides[1], rfa_k_strides[2],
        rfa_v_strides[0], rfa_v_strides[1], rfa_v_strides[2],
        mask_strides[0], mask_strides[1],
        o.stride(0), o.stride(1), o.stride(2),
        nheads,
        seqlen_q,
        seqlen_k,
        nchunks,
        head_dim,
        # Triton cache keys: bucket the lengths so small changes in sequence
        # length do not trigger recompilation.
        seqlen_q // 32,
        seqlen_k // 32,
        nchunks // 32,
        chunks_per_window,
        window_size,
        mask_type,
        empty_rfa_kv,
        BLOCK_HEADDIM,
        BLOCK_M=BLOCK,
        BLOCK_N=BLOCK,
        num_warps=num_warps,
        num_stages=1,
    )
    return o
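# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): the shapes below are
# illustrative placeholders chosen to satisfy the asserts in
# `triton_eva_agg_fwd` (bf16, head_dim <= 128, chunks_per_window >= BLOCK).
# Random values only exercise shapes and the launch path, not model quality;
# in a real model rfa_k / rfa_v would come from the EVA front-end.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    batch, nheads, seqlen, head_dim = 2, 4, 1024, 64
    window_size = 256          # local-attention window length
    chunks_per_window = 128    # must be >= BLOCK (128 for bf16 inputs)
    nchunks = (seqlen // window_size) * chunks_per_window

    dtype, device = torch.bfloat16, "cuda"
    q = torch.randn(batch, nheads, seqlen, head_dim, dtype=dtype, device=device)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    rfa_k = torch.randn(batch, nheads, nchunks, head_dim, dtype=dtype, device=device)
    rfa_v = torch.randn_like(rfa_k)

    out = triton_eva_agg_fwd(
        q, k, v, rfa_k, rfa_v,
        window_mask=None,      # None selects causal masking inside each window
        softmax_scale=None,    # defaults to 1 / sqrt(head_dim)
        window_size=window_size,
        chunks_per_window=chunks_per_window,
    )
    print(out.shape)  # torch.Size([2, 4, 1024, 64])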
|
|