File size: 16,383 Bytes
474addc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 |
import math
import torch
import triton
import triton.language as tl
# Disabling autotune for now, set num_warps=4 if headdim=64 and num_warps=8 if headdim=128
# @triton.autotune(
# configs=[
# triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=1),
# # This config has a race condition when EVEN_M == False, disabling it for now.
# # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
# ],
# key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM']
# )
# The EVEN_* flags below are computed from the launch arguments. Inside the
# kernel they pick between unmasked loads/stores and bounds-masked ones.
@triton.heuristics(
    {
        "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
        "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
        "EVEN_C": lambda args: args["nchunks"] % args["BLOCK_N"] == 0,
        "EVEN_W": lambda args: args["WINDOW_SIZE"] % args["BLOCK_N"] == 0,
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
@triton.jit
def _fwd_eva_agg_kernel(
    Q,
    K,
    V,
    RFA_K,
    RFA_V,
    WindowMask,
    Out,
    softmax_scale,
    stride_qb, stride_qh, stride_qm,
    stride_kb, stride_kh, stride_kn,
    stride_vb, stride_vh, stride_vn,
    stride_rfa_kb, stride_rfa_kh, stride_rfa_kc,
    stride_rfa_vb, stride_rfa_vh, stride_rfa_vc,
    stride_mb, stride_mm,
    stride_ob, stride_oh, stride_om,
    nheads,
    seqlen_q,
    seqlen_k,
    nchunks,
    headdim,
    CACHE_KEY_SEQLEN_Q,  # not read in the kernel body. TODO: why keeping this
    CACHE_KEY_SEQLEN_K,  # not read in the kernel body. TODO: why keeping this
    CACHE_KEY_NCHUNKS,  # not read in the kernel body. TODO: why keeping this
    CHUNKS_PER_WINDOW: tl.constexpr,
    WINDOW_SIZE: tl.constexpr,
    MASK_TYPE: tl.constexpr,
    EMPTY_RFA_KV: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_M: tl.constexpr,
    EVEN_N: tl.constexpr,
    EVEN_W: tl.constexpr,
    EVEN_C: tl.constexpr,
    EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Forward attention for one block of query rows of one (batch, head) pair.

    Program axis 0 selects a block of BLOCK_M query rows; program axis 1 is a
    flattened (batch, head) index that is split using ``nheads``.

    The program accumulates a running-max / running-sum softmax over two
    sources and writes the normalized result to ``Out``:

    1. Rows of ``K``/``V`` from the start of the window containing the first
       query row of this block (``offs_w * WINDOW_SIZE``) up to
       ``min((start_m + 1) * BLOCK_M, seqlen_k)``.
    2. When ``EMPTY_RFA_KV == 0``: rows of ``RFA_K``/``RFA_V`` with index below
       ``min(offs_w * CHUNKS_PER_WINDOW, nchunks)``.

    ``MASK_TYPE == 1`` adds values loaded from ``WindowMask`` to the scaled
    scores; any other value applies a causal mask (query index >= key index).

    All pointer arithmetic adds ``offs_d`` without a stride, i.e. the last
    dimension of every tensor is addressed with unit stride.
    """
    start_m = tl.program_id(0)
    off_bh = tl.program_id(1)
    # Split the flattened (batch, head) program index.
    off_h = off_bh % nheads
    off_b = off_bh // nheads
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # Index of the window that contains the first query row of this block.
    offs_w = (start_m * BLOCK_M) // WINDOW_SIZE
    offs_n = tl.arange(0, BLOCK_N)
    # Same values as offs_n: RFA chunks are also processed BLOCK_N at a time.
    offs_c = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    # TODO: add parentheses or not
    q_ptrs = (
        Q +
        off_b * stride_qb +
        off_h * stride_qh +
        (offs_m[:, None] * stride_qm + offs_d[None, :])
    )
    k_ptrs = (
        K +
        off_b * stride_kb +
        off_h * stride_kh +
        (offs_n[:, None] * stride_kn + offs_d[None, :])
    )
    v_ptrs = (
        V +
        off_b * stride_vb +
        off_h * stride_vh +
        (offs_n[:, None] * stride_vn + offs_d[None, :])
    )
    if EMPTY_RFA_KV == 0:
        rfa_k_ptrs = (
            RFA_K +
            off_b * stride_rfa_kb +
            off_h * stride_rfa_kh +
            (offs_c[:, None] * stride_rfa_kc + offs_d[None, :])
        )
        rfa_v_ptrs = (
            RFA_V +
            off_b * stride_rfa_vb +
            off_h * stride_rfa_vh +
            (offs_c[:, None] * stride_rfa_vc + offs_d[None, :])
        )
    # Fold log2(e) into the scale so the softmax can use exp2 instead of exp.
    qk_scale = softmax_scale
    qk_scale *= 1.4426950408889634  # log2(e)
    if MASK_TYPE == 1:
        # The mask is addressed by batch and query row only (no head offset);
        # its columns are positions relative to the start of the window.
        m_ptrs = (
            WindowMask +
            off_b * stride_mb +
            (offs_m[:, None] * stride_mm + offs_n[None, :])
        )
    # Running row-wise maximum (m_i), softmax denominator (d_i) and
    # un-normalized output (acc_o), all kept in fp32.
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    d_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
    # load q: it will stay in SRAM throughout
    # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call
    # tl.load(q_ptrs), we get the wrong output!
    if EVEN_M & EVEN_N:
        if EVEN_HEADDIM:
            q = tl.load(
                q_ptrs
            )
        else:
            q = tl.load(
                q_ptrs,
                mask=offs_d[None, :] < headdim,
                other=0.0
            )
    else:
        if EVEN_HEADDIM:
            q = tl.load(
                q_ptrs,
                mask=offs_m[:, None] < seqlen_q,
                other=0.0
            )
        else:
            q = tl.load(
                q_ptrs,
                mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                other=0.0
            )
    # loop over k, v and update accumulator
    # Iterate over local singletons;
    # so we only iterate over blocks within the current window
    start_idx_n = offs_w * WINDOW_SIZE
    end_idx_n = tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
    for start_n in range(start_idx_n, end_idx_n, BLOCK_N):
        # NOTE(review): this hint is only true if start_idx_n (and hence
        # WINDOW_SIZE) is a multiple of BLOCK_N - confirm against the caller.
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        if EVEN_N & EVEN_M:
            if EVEN_HEADDIM:
                k = tl.load(
                    k_ptrs + start_n * stride_kn
                )
            else:
                k = tl.load(
                    k_ptrs + start_n * stride_kn,
                    mask=offs_d[None, :] < headdim,
                    other=0.0
                )
        else:
            if EVEN_HEADDIM:
                k = tl.load(
                    k_ptrs + start_n * stride_kn,
                    mask=(start_n + offs_n)[:, None] < seqlen_k,
                    other=0.0,
                )
            else:
                k = tl.load(
                    k_ptrs + start_n * stride_kn,
                    mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                    other=0.0,
                )
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, tl.trans(k))
        # Trying to combine the two masks seem to make the result wrong
        if not EVEN_N:  # Need to mask out otherwise the softmax is wrong
            qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
        if MASK_TYPE == 1:
            # Shift the mask columns by this block's offset inside the window.
            if EVEN_M & EVEN_W:
                mask = tl.load(
                    m_ptrs + start_n - start_idx_n
                ).to(tl.float32)
            else:
                mask = tl.load(
                    m_ptrs + start_n - start_idx_n,
                    mask=(offs_m[:, None] < seqlen_q)
                    & ((start_n - start_idx_n + offs_n)[None, :] < WINDOW_SIZE),
                    other=0.0,
                ).to(tl.float32)
            # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
            # can then fuse the mult and add into an fma instruction. But if we have bias we need
            # to multiply with softmax_scale here.
            # we assume mask already implies the causal masking
            qk = qk * qk_scale + mask
            m_ij = tl.maximum(tl.max(qk, 1), m_i)
            p = tl.exp2(qk - m_ij[:, None])
        else:
            # No explicit mask: keep only keys at or before each query position.
            qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
            m_ij = tl.maximum(tl.max(qk, 1) * qk_scale, m_i)
            p = tl.exp2(qk * qk_scale - m_ij[:, None])
        d_ij = tl.sum(p, 1)
        # scale acc_o
        # Rescale what was accumulated under the old maximum to the new one.
        prev_scale = tl.exp2(m_i - m_ij)
        # # -- update output accumulator --
        acc_o = acc_o * prev_scale[:, None]
        # update acc_o
        if EVEN_N & EVEN_M:  # If we just do "if EVEN_N", there seems to be some race condition
            if EVEN_HEADDIM:
                v = tl.load(
                    v_ptrs + start_n * stride_vn
                )
            else:
                v = tl.load(
                    v_ptrs + start_n * stride_vn,
                    mask=offs_d[None, :] < headdim,
                    other=0.0
                )
        else:
            if EVEN_HEADDIM:
                v = tl.load(
                    v_ptrs + start_n * stride_vn,
                    mask=(start_n + offs_n)[:, None] < seqlen_k,
                    other=0.0,
                )
            else:
                v = tl.load(
                    v_ptrs + start_n * stride_vn,
                    mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                    other=0.0,
                )
        # Cast the probabilities to the value dtype before the second matmul.
        p = p.to(v.dtype)
        acc_o = tl.dot(p, v, acc_o)
        # -- update statistics
        d_i = d_i * prev_scale + d_ij
        m_i = m_ij
    if EMPTY_RFA_KV == 0:
        # Iterate over RFA chunks
        # we only iterate over chunks before the current local singleton window
        end_idx_c = tl.minimum(offs_w * CHUNKS_PER_WINDOW, nchunks)
        # NOTE(review): loads and the -inf mask below are bounded by nchunks,
        # not end_idx_c. If end_idx_c is not a multiple of BLOCK_N, the last
        # block appears to include chunks at or beyond end_idx_c - confirm
        # that the caller guarantees this alignment.
        for start_c in range(0, end_idx_c, BLOCK_N):
            start_c = tl.multiple_of(start_c, BLOCK_N)
            # -- compute qk ----
            if EVEN_C & EVEN_M:
                if EVEN_HEADDIM:
                    rfa_k = tl.load(
                        rfa_k_ptrs + start_c * stride_rfa_kc
                    )
                else:
                    rfa_k = tl.load(
                        rfa_k_ptrs + start_c * stride_rfa_kc,
                        mask=offs_d[None, :] < headdim,
                        other=0.0
                    )
            else:
                if EVEN_HEADDIM:
                    rfa_k = tl.load(
                        rfa_k_ptrs + start_c * stride_rfa_kc,
                        mask=(start_c + offs_c)[:, None] < nchunks,
                        other=0.0,
                    )
                else:
                    rfa_k = tl.load(
                        rfa_k_ptrs + start_c * stride_rfa_kc,
                        mask=((start_c + offs_c)[:, None] < nchunks) & (offs_d[None, :] < headdim),
                        other=0.0,
                    )
            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, tl.trans(rfa_k))
            # Trying to combine the two masks seem to make the result wrong
            if not EVEN_C:  # Need to mask out otherwise the softmax is wrong
                qk += tl.where((start_c + offs_c)[None, :] < nchunks, 0, float("-inf"))
            # Same running-max / running-sum update as in the local loop above.
            m_ij = tl.maximum(tl.max(qk, 1) * qk_scale, m_i)
            p = tl.exp2(qk * qk_scale - m_ij[:, None])
            d_ij = tl.sum(p, 1)
            # scale acc_o
            prev_scale = tl.exp2(m_i - m_ij)
            # # -- update output accumulator --
            acc_o = acc_o * prev_scale[:, None]
            # update acc_o
            # TODO: If we just do "if EVEN_N", there seems to be some race condition ?
            if EVEN_C & EVEN_M:
                if EVEN_HEADDIM:
                    rfa_v = tl.load(
                        rfa_v_ptrs + start_c * stride_rfa_vc
                    )
                else:
                    rfa_v = tl.load(
                        rfa_v_ptrs + start_c * stride_rfa_vc,
                        mask=offs_d[None, :] < headdim,
                        other=0.0
                    )
            else:
                # These masks use offs_n where the rfa_k loads use offs_c; both
                # are tl.arange(0, BLOCK_N), so the bounds are the same.
                if EVEN_HEADDIM:
                    rfa_v = tl.load(
                        rfa_v_ptrs + start_c * stride_rfa_vc,
                        mask=(start_c + offs_n)[:, None] < nchunks,
                        other=0.0,
                    )
                else:
                    rfa_v = tl.load(
                        rfa_v_ptrs + start_c * stride_rfa_vc,
                        mask=((start_c + offs_n)[:, None] < nchunks) & (offs_d[None, :] < headdim),
                        other=0.0,
                    )
            p = p.to(rfa_v.dtype)
            acc_o = tl.dot(p, rfa_v, acc_o)
            # -- update statistics
            d_i = d_i * prev_scale + d_ij
            m_i = m_ij
    # BUG: have to store and immediately load
    # Final softmax normalization by the accumulated denominator.
    acc_o = acc_o / d_i[:, None]
    # TODO: understand why rematerialize offsets to save registers?
    start_m = tl.program_id(0)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    out_ptrs = (
        Out +
        off_b * stride_ob +
        off_h * stride_oh +
        (offs_m[:, None] * stride_om + offs_d[None, :])
    )
    # Store, masking out rows past seqlen_q and columns past headdim as needed.
    if EVEN_M:
        if EVEN_HEADDIM:
            tl.store(
                out_ptrs, acc_o
            )
        else:
            tl.store(
                out_ptrs, acc_o,
                mask=offs_d[None, :] < headdim
            )
    else:
        if EVEN_HEADDIM:
            tl.store(
                out_ptrs, acc_o,
                mask=offs_m[:, None] < seqlen_q
            )
        else:
            tl.store(
                out_ptrs, acc_o,
                mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
            )
def triton_eva_agg_fwd(q, k, v, rfa_k, rfa_v, window_mask, softmax_scale, window_size, chunks_per_window):
    """Validate the inputs, launch ``_fwd_eva_agg_kernel`` and return its output.

    Args:
        q: Query tensor of shape ``(batch, nheads, seqlen_q, head_dim)``.
        k: Key tensor of shape ``(batch, nheads, seqlen_k, head_dim)``.
        v: Value tensor with the same shape as ``k``.
        rfa_k: Chunk-level keys of shape ``(batch, nheads, nchunks, head_dim)``,
            or ``None``. Must be ``None`` exactly when ``rfa_v`` is ``None``.
        rfa_v: Chunk-level values with the same shape as ``rfa_k``, or ``None``.
        window_mask: ``None``, or a tensor of shape
            ``(batch, 1, seqlen_q, window_size)`` with the same dtype as ``q``
            that the kernel adds to the scaled scores. With ``None`` the kernel
            applies a causal mask instead.
        softmax_scale: Score scaling factor. A falsy value (``None`` or ``0``)
            selects ``1 / sqrt(head_dim)``.
        window_size: Window length passed to the kernel as ``WINDOW_SIZE``.
        chunks_per_window: Number of chunks per window passed to the kernel as
            ``CHUNKS_PER_WINDOW``. Must be positive and at least the kernel
            block size (64 for fp32 inputs, 128 otherwise).

    Returns:
        A tensor created with ``torch.empty_like(q)`` and filled by the kernel.

    Raises:
        AssertionError: If any shape, dtype, device or size constraint is
            violated.
    """
    assert (rfa_k is None) == (rfa_v is None), "Both rfa_k and rfa_v must either be None or have values at the same time."
    empty_rfa_kv = 1 if rfa_k is None else 0

    def _with_unit_last_stride(x):
        # The kernel addresses the last dimension with an implicit stride of 1.
        return x if x.stride(-1) == 1 else x.contiguous()

    q, k, v = (_with_unit_last_stride(x) for x in (q, k, v))
    if empty_rfa_kv == 0:
        rfa_k, rfa_v = (_with_unit_last_stride(x) for x in (rfa_k, rfa_v))

    # shape constraints
    batch, nheads, seqlen_q, head_dim = q.shape
    _, _, seqlen_k, _ = k.shape
    if empty_rfa_kv == 0:
        nchunks = rfa_k.shape[-2]
        assert rfa_k.shape == (batch, nheads, nchunks, head_dim)
        assert rfa_v.shape == (batch, nheads, nchunks, head_dim)
        assert q.dtype == k.dtype == v.dtype == rfa_k.dtype == rfa_v.dtype, "All tensors must have the same type"
    else:
        nchunks = 0
        assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type"
    assert k.shape == (batch, nheads, seqlen_k, head_dim)
    assert v.shape == (batch, nheads, seqlen_k, head_dim)

    assert head_dim <= 128, "We only test head dimensions up to 128"
    # assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
    assert q.dtype in [torch.bfloat16, torch.float], "Only support bf16 and fp32 for now"

    softmax_scale = softmax_scale or 1.0 / math.sqrt(head_dim)

    # MASK_TYPE passed to the kernel:
    #   - 0: no window mask (window_mask is None); the kernel applies a causal mask
    #   - 1: window_mask of shape (batch, 1, seqlen_q, window_size) is added to the scores
    mask_type = 0 if window_mask is None else 1
    if mask_type == 1:
        assert window_mask.dim() == 4, "window_mask must be a 4-D tensor"
        assert window_mask.shape == (batch, 1, seqlen_q, window_size), (
            "window_mask must have shape (batch, 1, seqlen_q, window_size)"
        )
        assert window_mask.dtype == q.dtype, (
            f"window_mask must have the same dtype as q ({q.dtype}), got {window_mask.dtype}"
        )
        if window_mask.stride(-1) != 1:
            window_mask = window_mask.contiguous()
        # Dim 1 has size 1, so only the batch and row strides are needed.
        mask_strides = (window_mask.stride(0), window_mask.stride(2))
    else:
        mask_strides = (0, 0)

    # Every tensor whose pointer is dereferenced by the kernel must live on the GPU.
    device_tensors = [q, k, v]
    if empty_rfa_kv == 0:
        device_tensors += [rfa_k, rfa_v]
    if mask_type == 1:
        device_tensors.append(window_mask)
    assert all(t.is_cuda for t in device_tensors), "All input tensors must be CUDA tensors"

    if empty_rfa_kv == 0:
        rfa_k_strides = (rfa_k.stride(0), rfa_k.stride(1), rfa_k.stride(2))
        rfa_v_strides = (rfa_v.stride(0), rfa_v.stride(1), rfa_v.stride(2))
    else:
        # The kernel never dereferences the RFA pointers when EMPTY_RFA_KV == 1.
        rfa_k_strides = (0, 0, 0)
        rfa_v_strides = (0, 0, 0)

    assert chunks_per_window > 0, "chunks_per_window must be greater than 0"

    o = torch.empty_like(q)

    BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16)
    # Smaller tiles for fp32 inputs than for bf16.
    BLOCK = 64 if q.dtype == torch.float else 128
    num_warps = 4 if head_dim <= 64 else 8
    assert chunks_per_window >= BLOCK, f"chunks_per_window must be at least BLOCK ({BLOCK})"

    # One program per (block of BLOCK_M query rows, batch * head).
    grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
    _fwd_eva_agg_kernel[grid](
        q,
        k,
        v,
        rfa_k,
        rfa_v,
        window_mask,
        o,
        softmax_scale,
        q.stride(0), q.stride(1), q.stride(2),
        k.stride(0), k.stride(1), k.stride(2),
        v.stride(0), v.stride(1), v.stride(2),
        rfa_k_strides[0], rfa_k_strides[1], rfa_k_strides[2],
        rfa_v_strides[0], rfa_v_strides[1], rfa_v_strides[2],
        mask_strides[0], mask_strides[1],
        o.stride(0), o.stride(1), o.stride(2),
        nheads,
        seqlen_q,
        seqlen_k,
        nchunks,
        head_dim,
        # CACHE_KEY_* arguments: coarse (divided by 32) size buckets.
        seqlen_q // 32,
        seqlen_k // 32,
        nchunks // 32,
        chunks_per_window,
        window_size,
        mask_type,
        empty_rfa_kv,
        BLOCK_HEADDIM,
        BLOCK_M=BLOCK,
        BLOCK_N=BLOCK,
        num_warps=num_warps,
        num_stages=1,
    )
    return o
|