# EvaByte-SFT / eva_agg_kernel.py
import math
import torch
import triton
import triton.language as tl
# Autotune is disabled for now; num_warps is set to 4 for headdim <= 64 and 8 for larger head dims (up to 128)
# @triton.autotune(
# configs=[
# triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=1),
# # This config has a race condition when EVEN_M == False, disabling it for now.
# # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
# ],
# key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM']
# )
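# Forward kernel for EVA aggregated attention. Each query block attends to
# (i) the local "singleton" keys/values inside its own window and
# (ii) the chunk-level RFA summaries (RFA_K / RFA_V) of all chunks before that window,
# with both contributions combined through a single running (online) softmax.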
@triton.heuristics(
{
"EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
"EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
"EVEN_C": lambda args: args["nchunks"] % args["BLOCK_N"] == 0,
"EVEN_W": lambda args: args["WINDOW_SIZE"] % args["BLOCK_N"] == 0,
"EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
}
)
@triton.jit
def _fwd_eva_agg_kernel(
Q,
K,
V,
RFA_K,
RFA_V,
WindowMask,
Out,
softmax_scale,
stride_qb, stride_qh, stride_qm,
stride_kb, stride_kh, stride_kn,
stride_vb, stride_vh, stride_vn,
stride_rfa_kb, stride_rfa_kh, stride_rfa_kc,
stride_rfa_vb, stride_rfa_vh, stride_rfa_vc,
stride_mb, stride_mm,
stride_ob, stride_oh, stride_om,
nheads,
seqlen_q,
seqlen_k,
nchunks,
headdim,
    CACHE_KEY_SEQLEN_Q,  # bucketed seqlen_q (seqlen_q // 32), kept as an autotune key (see the disabled configs above)
    CACHE_KEY_SEQLEN_K,  # bucketed seqlen_k (seqlen_k // 32)
    CACHE_KEY_NCHUNKS,   # bucketed nchunks (nchunks // 32)
CHUNKS_PER_WINDOW: tl.constexpr,
WINDOW_SIZE: tl.constexpr,
MASK_TYPE: tl.constexpr,
EMPTY_RFA_KV: tl.constexpr,
BLOCK_HEADDIM: tl.constexpr,
EVEN_M: tl.constexpr,
EVEN_N: tl.constexpr,
EVEN_W: tl.constexpr,
EVEN_C: tl.constexpr,
EVEN_HEADDIM: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
start_m = tl.program_id(0)
off_bh = tl.program_id(1)
off_h = off_bh % nheads
off_b = off_bh // nheads
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_w = (start_m * BLOCK_M) // WINDOW_SIZE
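    # offs_w: index of the local window that contains the first query of this block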
offs_n = tl.arange(0, BLOCK_N)
offs_c = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_HEADDIM)
    # pointers to the (BLOCK_M, BLOCK_HEADDIM) tile of Q and the (BLOCK_N, BLOCK_HEADDIM) tiles of K/V for this (batch, head)
q_ptrs = (
Q +
off_b * stride_qb +
off_h * stride_qh +
(offs_m[:, None] * stride_qm + offs_d[None, :])
)
k_ptrs = (
K +
off_b * stride_kb +
off_h * stride_kh +
(offs_n[:, None] * stride_kn + offs_d[None, :])
)
v_ptrs = (
V +
off_b * stride_vb +
off_h * stride_vh +
(offs_n[:, None] * stride_vn + offs_d[None, :])
)
if EMPTY_RFA_KV == 0:
rfa_k_ptrs = (
RFA_K +
off_b * stride_rfa_kb +
off_h * stride_rfa_kh +
(offs_c[:, None] * stride_rfa_kc + offs_d[None, :])
)
rfa_v_ptrs = (
RFA_V +
off_b * stride_rfa_vb +
off_h * stride_rfa_vh +
(offs_c[:, None] * stride_rfa_vc + offs_d[None, :])
)
qk_scale = softmax_scale
qk_scale *= 1.4426950408889634 # log2(e)
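    # fold log2(e) into the scale so we can use the faster exp2 instead of exp below:
    # exp(x) == exp2(x * log2(e))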
if MASK_TYPE == 1:
m_ptrs = (
WindowMask +
off_b * stride_mb +
(offs_m[:, None] * stride_mm + offs_n[None, :])
)
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
d_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
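    # running online-softmax statistics: m_i is the running max of the scaled logits
    # (in log2 space), d_i the running sum of exponentials, acc_o the unnormalized output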
# load q: it will stay in SRAM throughout
# [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call
# tl.load(q_ptrs), we get the wrong output!
if EVEN_M & EVEN_N:
if EVEN_HEADDIM:
q = tl.load(
q_ptrs
)
else:
q = tl.load(
q_ptrs,
mask=offs_d[None, :] < headdim,
other=0.0
)
else:
if EVEN_HEADDIM:
q = tl.load(
q_ptrs,
mask=offs_m[:, None] < seqlen_q,
other=0.0
)
else:
q = tl.load(
q_ptrs,
mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
other=0.0
)
# loop over k, v and update accumulator
# Iterate over local singletons;
# so we only iterate over blocks within the current window
start_idx_n = offs_w * WINDOW_SIZE
end_idx_n = tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
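    # the local window starts at start_idx_n; causality limits keys to those seen by
    # the last query of this block, hence the cap at (start_m + 1) * BLOCK_M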
for start_n in range(start_idx_n, end_idx_n, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
if EVEN_N & EVEN_M:
if EVEN_HEADDIM:
k = tl.load(
k_ptrs + start_n * stride_kn
)
else:
k = tl.load(
k_ptrs + start_n * stride_kn,
mask=offs_d[None, :] < headdim,
other=0.0
)
else:
if EVEN_HEADDIM:
k = tl.load(
k_ptrs + start_n * stride_kn,
mask=(start_n + offs_n)[:, None] < seqlen_k,
other=0.0,
)
else:
k = tl.load(
k_ptrs + start_n * stride_kn,
mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
other=0.0,
)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, tl.trans(k))
        # Trying to combine the two masks seems to make the result wrong
if not EVEN_N: # Need to mask out otherwise the softmax is wrong
qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
if MASK_TYPE == 1:
if EVEN_M & EVEN_W:
mask = tl.load(
m_ptrs + start_n - start_idx_n
).to(tl.float32)
else:
mask = tl.load(
m_ptrs + start_n - start_idx_n,
mask=(offs_m[:, None] < seqlen_q)
& ((start_n - start_idx_n + offs_n)[None, :] < WINDOW_SIZE),
other=0.0,
).to(tl.float32)
            # Slightly faster to fold the softmax_scale into the tl.exp2 below, since the compiler
            # can then fuse the multiply and add into an fma instruction. But if we have a mask
            # we need to multiply with softmax_scale here.
            # We assume the mask already encodes the causal masking.
qk = qk * qk_scale + mask
m_ij = tl.maximum(tl.max(qk, 1), m_i)
p = tl.exp2(qk - m_ij[:, None])
else:
qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
m_ij = tl.maximum(tl.max(qk, 1) * qk_scale, m_i)
p = tl.exp2(qk * qk_scale - m_ij[:, None])
d_ij = tl.sum(p, 1)
# scale acc_o
prev_scale = tl.exp2(m_i - m_ij)
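        # rescale the existing accumulator and denominator by exp2(m_i - m_ij)
        # so every term is expressed relative to the new running max m_ij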
# # -- update output accumulator --
acc_o = acc_o * prev_scale[:, None]
# update acc_o
if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition
if EVEN_HEADDIM:
v = tl.load(
v_ptrs + start_n * stride_vn
)
else:
v = tl.load(
v_ptrs + start_n * stride_vn,
mask=offs_d[None, :] < headdim,
other=0.0
)
else:
if EVEN_HEADDIM:
v = tl.load(
v_ptrs + start_n * stride_vn,
mask=(start_n + offs_n)[:, None] < seqlen_k,
other=0.0,
)
else:
v = tl.load(
v_ptrs + start_n * stride_vn,
mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
other=0.0,
)
p = p.to(v.dtype)
acc_o = tl.dot(p, v, acc_o)
# -- update statistics
d_i = d_i * prev_scale + d_ij
m_i = m_ij
if EMPTY_RFA_KV == 0:
# Iterate over RFA chunks
# we only iterate over chunks before the current local singleton window
end_idx_c = tl.minimum(offs_w * CHUNKS_PER_WINDOW, nchunks)
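        # chunk-level attention covers only chunks strictly before the current window;
        # everything inside the window is handled by the singleton loop above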
for start_c in range(0, end_idx_c, BLOCK_N):
start_c = tl.multiple_of(start_c, BLOCK_N)
# -- compute qk ----
if EVEN_C & EVEN_M:
if EVEN_HEADDIM:
rfa_k = tl.load(
rfa_k_ptrs + start_c * stride_rfa_kc
)
else:
rfa_k = tl.load(
rfa_k_ptrs + start_c * stride_rfa_kc,
mask=offs_d[None, :] < headdim,
other=0.0
)
else:
if EVEN_HEADDIM:
rfa_k = tl.load(
rfa_k_ptrs + start_c * stride_rfa_kc,
mask=(start_c + offs_c)[:, None] < nchunks,
other=0.0,
)
else:
rfa_k = tl.load(
rfa_k_ptrs + start_c * stride_rfa_kc,
mask=((start_c + offs_c)[:, None] < nchunks) & (offs_d[None, :] < headdim),
other=0.0,
)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, tl.trans(rfa_k))
            # Trying to combine the two masks seems to make the result wrong
if not EVEN_C: # Need to mask out otherwise the softmax is wrong
qk += tl.where((start_c + offs_c)[None, :] < nchunks, 0, float("-inf"))
m_ij = tl.maximum(tl.max(qk, 1) * qk_scale, m_i)
p = tl.exp2(qk * qk_scale - m_ij[:, None])
d_ij = tl.sum(p, 1)
# scale acc_o
prev_scale = tl.exp2(m_i - m_ij)
# # -- update output accumulator --
acc_o = acc_o * prev_scale[:, None]
# update acc_o
            # If we only checked EVEN_C here, there seems to be a race condition (as with EVEN_N above)
if EVEN_C & EVEN_M:
if EVEN_HEADDIM:
rfa_v = tl.load(
rfa_v_ptrs + start_c * stride_rfa_vc
)
else:
rfa_v = tl.load(
rfa_v_ptrs + start_c * stride_rfa_vc,
mask=offs_d[None, :] < headdim,
other=0.0
)
else:
if EVEN_HEADDIM:
rfa_v = tl.load(
rfa_v_ptrs + start_c * stride_rfa_vc,
                        mask=(start_c + offs_c)[:, None] < nchunks,
other=0.0,
)
else:
rfa_v = tl.load(
rfa_v_ptrs + start_c * stride_rfa_vc,
                        mask=((start_c + offs_c)[:, None] < nchunks) & (offs_d[None, :] < headdim),
other=0.0,
)
p = p.to(rfa_v.dtype)
acc_o = tl.dot(p, rfa_v, acc_o)
# -- update statistics
d_i = d_i * prev_scale + d_ij
m_i = m_ij
    # final normalization: divide the accumulated output by the softmax denominator
acc_o = acc_o / d_i[:, None]
    # rematerialize offsets to save registers
start_m = tl.program_id(0)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_d = tl.arange(0, BLOCK_HEADDIM)
out_ptrs = (
Out +
off_b * stride_ob +
off_h * stride_oh +
(offs_m[:, None] * stride_om + offs_d[None, :])
)
if EVEN_M:
if EVEN_HEADDIM:
tl.store(
out_ptrs, acc_o
)
else:
tl.store(
out_ptrs, acc_o,
mask=offs_d[None, :] < headdim
)
else:
if EVEN_HEADDIM:
tl.store(
out_ptrs, acc_o,
mask=offs_m[:, None] < seqlen_q
)
else:
tl.store(
out_ptrs, acc_o,
mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
)
def triton_eva_agg_fwd(q, k, v, rfa_k, rfa_v, window_mask, softmax_scale, window_size, chunks_per_window):
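    """EVA aggregated-attention forward pass (Triton).

    Expected shapes (checked by the asserts below):
        q:            (batch, nheads, seqlen_q, head_dim)
        k, v:         (batch, nheads, seqlen_k, head_dim)
        rfa_k, rfa_v: (batch, nheads, nchunks, head_dim), or both None
        window_mask:  (batch, 1, seqlen_q, window_size), or None for plain causal masking

    Returns a tensor of the same shape and dtype as q.
    """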
if rfa_k is None and rfa_v is None:
empty_rfa_kv = 1
q, k, v = [
x if x.stride(-1) == 1 else x.contiguous()
for x in [q, k, v]
]
else:
        assert rfa_k is not None and rfa_v is not None, "rfa_k and rfa_v must both be None or both be provided."
empty_rfa_kv = 0
q, k, v, rfa_k, rfa_v = [
x if x.stride(-1) == 1 else x.contiguous()
for x in [q, k, v, rfa_k, rfa_v]
]
# shape constraints
batch, nheads, seqlen_q, head_dim = q.shape
_, _, seqlen_k, _ = k.shape
if empty_rfa_kv == 0:
nchunks = rfa_k.shape[-2]
assert rfa_k.shape == (batch, nheads, nchunks, head_dim)
assert rfa_v.shape == (batch, nheads, nchunks, head_dim)
assert q.dtype == k.dtype == v.dtype == rfa_k.dtype == rfa_v.dtype
else:
nchunks = 0
        assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype"
assert k.shape == (batch, nheads, seqlen_k, head_dim)
assert v.shape == (batch, nheads, seqlen_k, head_dim)
assert head_dim <= 128, "We only test head dimensions up to 128"
# assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
assert q.dtype in [torch.bfloat16, torch.float], "Only support bf16 and fp32 for now"
assert q.is_cuda and k.is_cuda and v.is_cuda
softmax_scale = softmax_scale or 1.0 / math.sqrt(head_dim)
mask_type = 0
if window_mask is not None:
mask_type = 1
        assert window_mask.dtype in [q.dtype, torch.float], "window_mask must be fp32 or match q's dtype"
assert window_mask.is_cuda
assert window_mask.dim() == 4
assert window_mask.shape == (batch, 1, seqlen_q, window_size)
if window_mask.stride(-1) != 1:
window_mask = window_mask.contiguous()
mask_strides = (
(window_mask.stride(0), window_mask.stride(2))
if mask_type == 1 else
(0, 0)
)
rfa_k_strides = (
(rfa_k.stride(0), rfa_k.stride(1), rfa_k.stride(2))
if empty_rfa_kv == 0 else
(0, 0, 0)
)
rfa_v_strides = (
(rfa_v.stride(0), rfa_v.stride(1), rfa_v.stride(2))
if empty_rfa_kv == 0 else
(0, 0, 0)
)
assert chunks_per_window > 0, "chunks_per_window must be greater than 0"
o = torch.empty_like(q)
BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16)
if q.dtype == torch.float:
BLOCK = 64
else:
BLOCK = 128
num_warps = 4 if head_dim <= 64 else 8
    assert chunks_per_window >= BLOCK, "chunks_per_window must be at least BLOCK"
    # MASK_TYPE:
    # - 0: regular causal mask; window_mask is simply None
    # - 1: window_mask must have shape (batch, 1, seqlen_q, window_size)
grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
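    # one program instance per (query block, batch * head) pair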
_fwd_eva_agg_kernel[grid](
q,
k,
v,
rfa_k,
rfa_v,
window_mask,
o,
softmax_scale,
q.stride(0), q.stride(1), q.stride(2),
k.stride(0), k.stride(1), k.stride(2),
v.stride(0), v.stride(1), v.stride(2),
rfa_k_strides[0], rfa_k_strides[1], rfa_k_strides[2],
rfa_v_strides[0], rfa_v_strides[1], rfa_v_strides[2],
mask_strides[0], mask_strides[1],
o.stride(0), o.stride(1), o.stride(2),
nheads,
seqlen_q,
seqlen_k,
nchunks,
head_dim,
seqlen_q // 32,
seqlen_k // 32,
nchunks // 32,
chunks_per_window,
window_size,
mask_type,
empty_rfa_kv,
BLOCK_HEADDIM,
BLOCK_M=BLOCK,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,
)
return o
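

# A minimal usage sketch: assumes a CUDA device with Triton available. The shapes
# below are illustrative and chosen to satisfy the asserts above, in particular
# chunks_per_window >= BLOCK (128 for bf16); nchunks is derived under the assumed
# relation chunk_size = window_size // chunks_per_window.
if __name__ == "__main__":
    torch.manual_seed(0)
    batch, nheads, seqlen, head_dim = 1, 2, 4096, 64
    window_size, chunks_per_window = 2048, 128   # assumed chunk size of 16
    nchunks = seqlen // (window_size // chunks_per_window)
    dtype, device = torch.bfloat16, "cuda"
    q, k, v = (
        torch.randn(batch, nheads, seqlen, head_dim, dtype=dtype, device=device)
        for _ in range(3)
    )
    rfa_k, rfa_v = (
        torch.randn(batch, nheads, nchunks, head_dim, dtype=dtype, device=device)
        for _ in range(2)
    )
    out = triton_eva_agg_fwd(
        q, k, v, rfa_k, rfa_v,
        window_mask=None,        # MASK_TYPE == 0: plain causal masking within the window
        softmax_scale=None,      # defaults to 1 / sqrt(head_dim)
        window_size=window_size,
        chunks_per_window=chunks_per_window,
    )
    print(out.shape)             # torch.Size([1, 2, 4096, 64])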