# Source: MInference / minference/modules/minference_forward.py
# Last commit 24083d5 by iofu728: "Feature(MInference): update the pycuda"
# Copyright (c) 2024 Microsoft
# Licensed under The MIT License [see LICENSE for details]
import inspect
import json
import os
from importlib import import_module

from transformers.models.llama.modeling_llama import *
from transformers.utils.import_utils import _is_package_available

if _is_package_available("vllm"):
    from vllm.attention.backends.flash_attn import *

from ..ops.block_sparse_flash_attention import block_sparse_attention
from ..ops.pit_sparse_flash_attention_v2 import vertical_slash_sparse_attention
from ..ops.streaming_kernel import streaming_forward, streaming_forward2
from .snap_kv import *
# Number of trailing query rows used to estimate sparse patterns online.
last_q = 64
# NOTE(review): allocated on the default CUDA device at import time, so importing this
# module requires a GPU.
arange = torch.arange(last_q, device="cuda")
# Causal mask of shape (1, 1, last_q, last_q): True where key index <= query index.
LAST_Q_MASK = arange[None, None, :, None] >= arange[None, None, None, :]
# Lazily set on the first forward: whether `rotary_emb.forward` accepts `seq_len`.
ROPE_TYPE = None
# Lazily built causal mask cached by `search_pattern`.
SEARCH_MASK = None
def init_minference_parameters(self):
    """Populate MInference attributes on an attention module from its config.

    Reads ``self.config.to_dict()``:
      * ``starting_layer`` (default 0) -> ``self.starting_layer``
      * ``is_search`` (default False) -> ``self.is_search``
      * ``config_path`` (default "") -> ``self.config_path``. If that file exists and
        has an entry for ``self.layer_idx``, the entry (a dict keyed by head index)
        becomes ``self.best_pattern`` with int keys; otherwise ``{}``.

    Also resets ``self.ne_inf``, ``self.vertical`` and ``self.slash`` to ``None``. The
    first time it runs on a module, it rebinds the module-level
    ``apply_rotary_pos_emb`` to the one defined in the module of
    ``self.rotary_emb``'s class.
    """
    global apply_rotary_pos_emb
    config = self.config.to_dict()
    self.starting_layer = config.get("starting_layer", 0)
    self.is_search = config.get("is_search", False)
    # self.n_init = config.get("n_init", 128)
    # self.n_local = config.get("n_local", 3968)
    self.ne_inf = None
    self.config_path = config.get("config_path", "")
    self.best_pattern = {}
    if os.path.exists(self.config_path):
        # Parse the pattern file once and close the handle deterministically.
        with open(self.config_path) as f:
            layer_patterns = json.load(f)
        if self.layer_idx < len(layer_patterns):
            # JSON object keys are strings; heads are looked up by integer index.
            self.best_pattern = {int(ii): jj for ii, jj in layer_patterns[self.layer_idx].items()}
    self.vertical, self.slash = None, None
    # Use the model's own apply_rotary_pos_emb instead of the Llama one from the
    # star import; the instance attribute marks that this was already done.
    if "apply_rotary_pos_emb" not in self.__dict__:
        model_path = self.rotary_emb.__class__.__module__
        apply_rotary_pos_emb = getattr(import_module(model_path), "apply_rotary_pos_emb")
        self.apply_rotary_pos_emb = True
def sum_all_diagonal_matrix(mat: torch.Tensor):
    """Sum every diagonal of the trailing two dimensions of ``mat``.

    Args:
        mat: tensor of shape ``(b, h, n, m)``.

    Returns:
        Tensor of shape ``(b, h, n + m - 1)``. Entry ``k`` is the sum of the diagonal
        with offset ``k - (n - 1)`` (column minus row). Entry 0 is the bottom-left
        corner; the last entry is the top-right corner.
    """
    b, h, n, m = mat.shape
    # Zero padding (default dtype, as before) on both sides.
    zero_mat = torch.zeros((b, h, n, n), device=mat.device)
    mat_padded = torch.cat((zero_mat, mat, zero_mat), -1)  # (b, h, n, 2n + m), contiguous
    row_len = 2 * n + m
    # A row stride of row_len + 1 shifts row i right by i, so diagonals become columns.
    # Explicit batch/head strides make every (b, h) slice correct, not just the first.
    mat_strided = mat_padded.as_strided(
        (b, h, n, n + m),
        (h * n * row_len, n * row_len, row_len + 1, 1),
    )
    sum_diags = torch.sum(mat_strided, 2)  # sum over rows -> one value per diagonal
    # Column 0 is offset -n, which never overlaps the matrix.
    return sum_diags[:, :, 1:]
def gather(t, dim, i):
    """Broadcasting wrapper around ``torch.gather``.

    ``i`` is expanded to ``t``'s shape on every axis except ``dim`` (a negative ``dim``
    counts from the end), then values are gathered from ``t`` along ``dim``.
    """
    axis = dim + t.ndim if dim < 0 else dim
    target_shape = list(t.shape)
    target_shape[axis] = i.shape[axis]
    return t.gather(axis, i.expand(*target_shape))
def gather_qkv(q, k, v, attention_mask):
    """Dense scaled dot-product attention with an additive mask.

    Scores are ``q @ k^T / sqrt(head_dim) + attention_mask``. The softmax runs in
    float32 and is cast back to ``q.dtype`` before weighting ``v``.
    """
    scale = math.sqrt(q.size(-1))
    scores = q @ k.transpose(2, 3) / scale + attention_mask
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
    return probs @ v
def search_pattern(q, k, head):
    """Offline search for the best sparse-attention pattern of one head.

    Computes dense causal attention for the head, then scores candidate masks by the
    attention mass they keep, averaged over query rows 2500 onward:
    ``stream_llm`` (leading columns + local band), ``vertical_and_slash`` and
    ``block_sparse``.

    Args:
        q, k: single-head query/key tensors. The helpers hard-code batch = head = 1,
            and the (q_len, q_len) mask assumes k has q_len positions. TODO confirm.
        head: head index, only used for the log line.

    Returns:
        ``(type, vertical_size, slash_size, score)``. ``stream_llm`` winners are
        reported as ``vertical_and_slash``; ``block_sparse`` winners are reported as
        ``vertical_and_slash`` with fixed sizes 1000 / 6096.
    """
    q_len = q.shape[2]
    head_dim = q.shape[-1]

    def vertical_and_slash(vertical_size, slash_size):
        # Estimate vertical/slash lines from the last 64 queries and measure the
        # dense attention mass the resulting mask keeps.
        last_q = 64
        q_len = q.shape[2]
        qk_idxs = [ii + q_len for ii in list(range(-last_q, 0, 1))]
        qk = torch.matmul(q[:, :, qk_idxs, :], k.transpose(2, 3)) / math.sqrt(head_dim) + attention_mask[:, :, qk_idxs]
        qk = torch.nn.functional.softmax(qk, dim=-1, dtype=torch.float32)
        vertical = qk.sum(-2, keepdim=True)
        vertical[..., :30] = 10000  # never drop the first 30 key columns
        # Indices of the columns to DROP (everything except the top `vertical_size`).
        vertical_topk = torch.topk(-vertical, q_len - vertical_size, -1).indices
        slash = sum_all_diagonal_matrix(qk)[..., :-last_q + 1]
        slash[..., -30:] = 10000
        slash = torch.topk(slash, slash_size, -1).indices - (q_len - 1)
        slash = torch.stack([torch.sparse.spdiags(torch.ones(slash_size, q_len), slash.cpu()[0][_], (q_len, q_len)).to_dense() for _ in range(1)]).to(q.device)
        est_attn = torch.ones_like(attn_weights)
        dim = 3
        est_attn = est_attn.scatter(3, vertical_topk.expand(*est_attn.shape[:dim], vertical_topk.shape[dim], *est_attn.shape[dim + 1:]), 0)
        est_attn = est_attn + slash
        est_attn = (est_attn > 0).float()
        est_attn = torch.tril(est_attn)
        attn_weights_x = attn_weights * est_attn
        res3 = attn_weights_x[:, :, 2500:].sum(-1).mean(-1).squeeze().float().detach().cpu().numpy()
        return res3

    def stream_llm(vertical_size, slash_size):
        # Leading `vertical_size` columns plus a band of width `slash_size` below the diagonal.
        q_len = q.shape[2]
        mask = torch.triu(torch.tril(torch.ones(q_len, q_len), 0), -slash_size).to(q)
        mask[:, :vertical_size] = 1
        mask = mask.unsqueeze(0).unsqueeze(1)
        est_attn = torch.tril(mask)
        attn_weights_x = attn_weights * est_attn
        res3 = attn_weights_x[:, :, 2500:].sum(-1).mean(-1).squeeze().float().detach().cpu().numpy()
        return res3

    def block_sparse(topk_ratio, slash_size=None):
        # Mean-pool q/k over 32-token blocks and keep the top 1/topk_ratio key blocks
        # per query block, expanded back to token resolution.
        block_num = (q_len - 1) // 32 + 1
        block_q = torch.zeros(1, 1, block_num * 32, head_dim).to(q)
        block_q[:, :, :q_len] = q
        block_q = block_q.reshape(1, 1, block_num, 32, -1).mean(-2)
        block_k = torch.zeros(1, 1, block_num * 32, head_dim).to(k)
        block_k[:, :, :q_len] = k
        block_k = block_k.reshape(1, 1, block_num, 32, -1).mean(-2)
        qk = torch.matmul(block_q, block_k.transpose(2, 3)) + attention_mask[:, :, :block_num, :block_num]
        est_attn = torch.ones_like(qk)
        block_topk = torch.topk(-qk, block_num - block_num // topk_ratio, -1).indices
        dim = 3
        est_attn = est_attn.scatter(3, block_topk.expand(*est_attn.shape[:dim], block_topk.shape[dim], *est_attn.shape[dim + 1:]), 0)
        est_attn = est_attn.unsqueeze(3).unsqueeze(-1).repeat(1, 1, 1, 32, 1, 32).reshape(1, 1, block_num * 32, block_num * 32)[..., :q_len, :q_len]
        est_attn = torch.tril(est_attn)
        attn_weights_x = attn_weights * est_attn
        res2 = attn_weights_x[:, :, 2500:].sum(-1).mean(-1).squeeze().float().detach().cpu().numpy()
        return res2

    global SEARCH_MASK
    # Rebuild the cached causal mask whenever the sequence length or device changes.
    # Reusing a mask built for another length would break the additions below.
    if SEARCH_MASK is None or SEARCH_MASK.shape[-1] != q_len or SEARCH_MASK.device != q.device:
        attention_mask = torch.full((q_len, q_len), torch.finfo(q.dtype).min, device=q.device)
        mask_cond = torch.arange(attention_mask.size(-1), device=q.device)
        # Zero out key positions <= query position (causal).
        attention_mask.masked_fill_(mask_cond < (mask_cond + 1).view(attention_mask.size(-1), 1), 0)
        attention_mask = attention_mask[None, None, :]
        SEARCH_MASK = attention_mask
    else:
        attention_mask = SEARCH_MASK

    attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim) + attention_mask
    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
    best_s, best_v, best_score, best_ty = 0, 0, 0, ""
    all_info = []
    for ty, fc in [("stream_llm", stream_llm), ("vertical_and_slash", vertical_and_slash), ("block_sparse", block_sparse)]:
        if ty == "stream_llm":
            vs_list = [(100, 800)]
        elif ty == "vertical_and_slash":
            vs_list = [(30, 800), (100, 750), (500, 700), (3500, 100)]
        else:
            vs_list = [(8, 1)]
        for v_size, s_size in vs_list:
            score = fc(v_size, s_size)
            score = score.item()
            all_info.append([ty, v_size, s_size, score])
            if score > best_score:
                best_score = score
                best_s, best_v = s_size, v_size
                best_ty = ty
    if best_ty == "stream_llm":
        best_ty = "vertical_and_slash"
    if best_ty == "block_sparse":
        best_ty, best_v, best_s = "vertical_and_slash", 1000, 6096
    print(head, best_ty, best_v, best_s, best_score)
    return (best_ty, best_v, best_s, best_score)
def search_pattern_v2(q, k, v, head):
    """Kernel-based pattern search for one head.

    Uses dense flash attention as the reference, then runs each candidate sparse
    kernel (``streaming_forward``, vertical-and-slash, block-sparse) over several size
    settings. Each candidate is scored by the number of output elements that differ
    from the reference by more than 5e-3 (lower is better). The best configuration
    is printed.

    Returns:
        list of ``[type, vertical_size, slash_size, score]`` for every candidate tried.
    """
    q_len = q.shape[2]
    head_dim = q.shape[-1]
    bsz = q.shape[0]

    def vertical_and_slash_kernel(q, k, v, vertical_size, slash_size):
        # Clamp sizes into [30, q_len] and [50, q_len].
        vertical_size, slash_size = min(q_len, max(vertical_size, 30)), min(q_len, max(slash_size, 50))
        # Estimate from at most the last 64 queries. The shared mask is sliced to match
        # and moved to q's device (same as the v4 kernel).
        last_q = min(64, q_len)
        qk = torch.einsum('bhmk, bhnk -> bhmn', q[:, :, -last_q:, :], k)
        qk[:, :, :, -last_q:] = torch.where(
            LAST_Q_MASK[..., -last_q:, -last_q:].to(q.device), qk[:, :, :, -last_q:], -torch.inf
        )
        qk = torch.nn.functional.softmax(qk, dim=-1, dtype=torch.float32)
        vertical = qk.sum(-2, keepdim=True)
        vertical[..., :30] = torch.inf  # always keep the first 30 keys
        vertical_topk = torch.topk(vertical, vertical_size, -1).indices
        slash = sum_all_diagonal_matrix(qk)[..., :-last_q + 1]
        slash[..., -30:] = torch.inf
        slash = (q_len - 1) - torch.topk(slash, slash_size, -1).indices
        return vertical_slash_sparse_attention(q, k, v, vertical_topk, slash)

    def dense(q, k, v, vertical_size=None, slash_size=None):
        # Inputs are transposed to (bsz, seq, 1, head_dim); the result is viewed back
        # as (bsz, 1, q_len, head_dim).
        return flash_attn_func(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), 0.0, softmax_scale=None, causal=q_len != 1).view(bsz, 1, q_len, head_dim)

    def block_sparse_kernel(q, k, v, vertical_size=None, slash_size=None):
        topk = 100
        return block_sparse_attention(q, k, v, topk)

    best_s, best_v, best_score, best_ty = 0, 0, float("inf"), ""
    all_info = []
    ref = dense(q, k, v)
    for ty, fc in [("stream_llm", streaming_forward), ("vertical_and_slash", vertical_and_slash_kernel), ("block_sparse", block_sparse_kernel)]:
        if ty == "stream_llm":
            vs_list = [(100, 800)]
        elif ty == "vertical_and_slash":
            vs_list = [(30, 800), (100, 800), (100, 750), (500, 700), (3500, 100), (1000, 4096)]
        else:
            vs_list = [(10, 1)]
        for v_size, s_size in vs_list:
            score = fc(q, k, v, v_size, s_size)
            # Count output elements that deviate from dense attention by more than 5e-3.
            delta = ((ref - score).abs() > 5e-3).sum()
            score = delta.item()
            all_info.append([ty, v_size, s_size, score])
            if score < best_score:
                best_score = score
                best_s, best_v = s_size, v_size
                best_ty = ty
    print(head, best_ty, best_v, best_s, best_score)
    return all_info
def shift_matrix(mat):
    """Shift each row of a square matrix so row ``i`` ends at column ``i``.

    Args:
        mat: tensor of shape ``(b, h, n, n)``.

    Returns:
        Tensor of the same shape with ``out[..., i, t] = mat[..., i, n - 1 - (i - t)]``
        for ``t <= i`` and ``0`` for ``t > i``.
    """
    b, h, _, n = mat.shape
    zero_mat = torch.zeros((b, h, n, n), device=mat.device)  # padding, default dtype
    mat_padded = torch.cat((zero_mat, mat, zero_mat), -1)  # (b, h, n, 3n), contiguous
    row_len = 3 * n
    # A row stride of row_len - 1 shifts row i left by i. Explicit batch/head strides
    # make every (b, h) slice correct, not only the first.
    mat_strided = mat_padded.as_strided(
        (b, h, n, row_len),
        (h * n * row_len, n * row_len, row_len - 1, 1),
    )
    return mat_strided[..., 2 * n - 1:-1]
def repeat(self, q, k, v, attention_mask):
    """Attention where every query reuses the last query's scores by relative offset.

    A single query falls back to dense ``gather_qkv``. Otherwise the last query's
    scaled scores are tiled over all rows and shifted with ``shift_matrix`` so row
    ``i`` ends at column ``i``. The mask is then added, softmax runs in float32, and
    the result is applied to ``v``.
    """
    seq_len = q.shape[2]
    if seq_len == 1:
        return gather_qkv(q, k, v, attention_mask)
    last_scores = q[:, :, -1:, :] @ k.transpose(2, 3) / math.sqrt(self.head_dim)
    tiled = last_scores.repeat(1, 1, seq_len, 1)
    logits = shift_matrix(tiled) + attention_mask
    probs = nn.functional.softmax(logits, dim=-1, dtype=torch.float32).to(q.dtype)
    return probs @ v
def gather_last_q_vertical_slash_topk_v4(self, q, k, v, head_id):
    """Sparse attention for one head, dispatched on config flags and the searched pattern.

    Args:
        q, k, v: single-head tensors; the dense path views its output as
            ``(bsz, 1, q_len, head_dim)``.
        head_id: key into ``self.best_pattern``.

    Dispatch (first match wins) on ``self.config``:
      * ``dilated1`` / ``dilated2``: fixed dilated vertical-slash pattern.
      * ``dense``: flash attention.
      * ``streaming``: ``streaming_forward`` with ``config.streaming_kwargs``.
      * ``static_pattern``: vertical-slash indices cached on ``self.vs``.
      * ``vs_only``: vertical-slash with the searched sizes.

    Otherwise a single query uses dense attention, and longer inputs use the kernel
    named by ``self.best_pattern[head_id]`` (default vertical-slash 1000/6096).
    """
    q_len = q.shape[2]
    bsz = q.shape[0]
    # to_dict() rebuilds a dict on each call; read the flags once.
    config = self.config.to_dict()

    def dilated(q, k, v, kind):
        # First min(1024, q_len) keys plus a dilated local window.
        q_len = q.shape[2]
        n_init = min(1024, q_len)
        vertical_topk = torch.arange(0, n_init, device=q.device)[None, None, None, :]
        slash = torch.arange(0, q_len, device=q.device)
        if kind == 'dilated1':
            # every other position among the last 8192
            slash = slash[-8192::2][None, None, :]
        elif kind == 'dilated2':
            # last 2048 positions densely + every other one of the 4096 before them
            slash = torch.cat([slash[-2048:], slash[-6144:-2048:2]], 0)[None, None, :]
        else:
            raise ValueError(f"Unknown dilated pattern: {kind!r}")
        slash = (q_len - 1) - slash
        return vertical_slash_sparse_attention(q, k, v, vertical_topk, slash)

    def vertical_and_slash_kernel(q, k, v, vertical_size, slash_size):
        vertical_size, slash_size = min(q_len, max(vertical_size, 30)), min(q_len, max(slash_size, 50))
        last_q = min(64, q_len)
        qk = torch.einsum('bhmk, bhnk -> bhmn', q[:, :, -last_q:, :], k)
        qk[:, :, :, -last_q:] = torch.where(LAST_Q_MASK[..., -last_q:, -last_q:].to(q.device), qk[:, :, :, -last_q:], -torch.inf)
        qk = torch.nn.functional.softmax(qk, dim=-1, dtype=torch.float32)
        vertical = qk.sum(-2, keepdim=True)
        vertical[..., :30] = torch.inf
        vertical_topk = torch.topk(vertical, vertical_size, -1).indices
        slash = sum_all_diagonal_matrix(qk)[..., :-last_q + 1]
        slash[..., -100:] = torch.inf
        slash = (q_len - 1) - torch.topk(slash, slash_size, -1).indices
        return vertical_slash_sparse_attention(q, k, v, vertical_topk, slash)

    def vertical_and_slash_kernel_static(q, k, v, vertical_size, slash_size):
        # NOTE(review): the indices are cached on the module the first time and then
        # reused for every head and every later call. Confirm this sharing is intended.
        if "vs" in self.__dict__:
            vertical_topk, slash = self.vs
        else:
            vertical_size, slash_size = min(q_len, max(vertical_size, 30)), min(q_len, max(slash_size, 50))
            # Slice the shared mask and move it to q's device, as in the dynamic kernel.
            last_q = min(64, q_len)
            qk = torch.einsum('bhmk, bhnk -> bhmn', q[:, :, -last_q:, :], k)
            qk[:, :, :, -last_q:] = torch.where(LAST_Q_MASK[..., -last_q:, -last_q:].to(q.device), qk[:, :, :, -last_q:], -torch.inf)
            qk = torch.nn.functional.softmax(qk, dim=-1, dtype=torch.float32)
            vertical = qk.sum(-2, keepdim=True)
            vertical[..., :30] = torch.inf
            vertical_topk = torch.topk(vertical, vertical_size, -1).indices
            slash = sum_all_diagonal_matrix(qk)[..., :-last_q + 1]
            slash[..., -30:] = torch.inf
            slash = (q_len - 1) - torch.topk(slash, slash_size, -1).indices
            self.vs = vertical_topk, slash
        return vertical_slash_sparse_attention(q, k, v, vertical_topk, slash)

    def dense(q, k, v, vertical_size=None, slash_size=None):
        return flash_attn_func(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), 0.0, softmax_scale=None, causal=q_len != 1).view(bsz, 1, q_len, self.head_dim)

    def block_sparse_kernel(q, k, v, vertical_size=None, slash_size=None):
        topk = 100
        return block_sparse_attention(q, k, v, topk)

    if config.get("dilated1", False):
        return dilated(q, k, v, 'dilated1')
    if config.get("dilated2", False):
        return dilated(q, k, v, 'dilated2')
    if config.get("dense", False):
        return dense(q, k, v)
    if config.get("streaming", False):
        return streaming_forward(q, k, v, self.config.streaming_kwargs["n_init"], self.config.streaming_kwargs["n_local"])

    ty, vertical_size, slash_size, _ = self.best_pattern.get(head_id, ("vertical_and_slash", 1000, 6096, 1))

    if config.get("static_pattern", False):
        return vertical_and_slash_kernel_static(q, k, v, vertical_size, slash_size)
    if config.get("vs_only", False):
        return vertical_and_slash_kernel(q, k, v, vertical_size, slash_size)

    if q_len == 1:
        return dense(q, k, v)

    fc = {
        "stream_llm": streaming_forward,
        "vertical_and_slash": vertical_and_slash_kernel,
        "block_sparse": block_sparse_kernel,
    }[ty]
    return fc(q, k, v, vertical_size, slash_size)
def apply_rotary_pos_emb_single(q, cos, sin, position_ids, unsqueeze_dim=1):
    """Apply rotary position embedding to a single tensor (queries or keys).

    ``cos`` and ``sin`` are used as given (no gather by ``position_ids``, which is
    accepted only for signature compatibility). Each gets a singleton axis at
    ``unsqueeze_dim`` to broadcast against ``q``.
    """
    cos_b, sin_b = (t.unsqueeze(unsqueeze_dim) for t in (cos, sin))
    rotated = rotate_half(q)
    return q * cos_b + rotated * sin_b
def minference_forward():
    """Build a replacement attention ``forward`` that runs MInference on prefill.

    The returned function:
      1. projects Q/K/V, using either separate ``q_proj``/``k_proj``/``v_proj`` or a
         fused ``qkv_proj``;
      2. applies RoPE and updates the KV cache;
      3. for multi-token inputs, computes attention head by head (or searches
         per-head patterns when ``is_search`` is set);
      4. for single-token inputs, uses one flash-attention call.
    """
    def forward(
        self,
        hidden_states,
        attention_mask,
        position_ids,
        past_key_value,
        output_attentions,
        use_cache,
        **kwargs,
    ):
        global ROPE_TYPE
        self.init_minference_parameters()
        self.ne_inf = torch.finfo(hidden_states.dtype).min

        bsz, q_len, _ = hidden_states.size()

        if "q_proj" in self.__dict__["_modules"]:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)
        else:
            qkv = self.qkv_proj(hidden_states)
            # Fused layout is [q | k | v]; k and v have num_key_value_heads heads, so
            # the split must use explicit sizes (equal chunks break for GQA).
            query_pos = self.num_heads * self.head_dim
            key_value_pos = self.num_key_value_heads * self.head_dim
            query_states, key_states, value_states = torch.split(qkv, [query_pos, key_value_pos, key_value_pos], -1)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

        # Older rotary embeddings take `seq_len`, newer ones take `position_ids`.
        if ROPE_TYPE is None:
            ROPE_TYPE = "seq_len" in inspect.signature(self.rotary_emb.forward).parameters
        if ROPE_TYPE:
            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        else:
            cos, sin = self.rotary_emb(value_states, position_ids)

        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        if self.is_search:
            if os.path.exists(self.config_path):
                with open(self.config_path) as f:
                    config_list = json.load(f)
                if self.layer_idx < len(config_list):
                    # Explicit raise: a bare `assert` would vanish under `python -O`.
                    raise AssertionError(
                        f"Layer {self.layer_idx} already has a searched pattern in {self.config_path}"
                    )
            else:
                config_list = []
            config = {}
            print("Layer", self.layer_idx)

        if q_len != 1:
            output = torch.empty_like(query_states)
            for head in range(query_states.size(1)):
                q = query_states[:, head, :, :].unsqueeze(1)
                k = key_states[:, head, :, :].unsqueeze(1)
                v = value_states[:, head, :, :].unsqueeze(1)
                if self.is_search:
                    config[head] = search_pattern(q, k, head)
                if self.layer_idx >= self.starting_layer and not self.is_search:
                    attn_output = self.gather_last_q_vertical_slash_topk_v4(q, k, v, head)
                elif is_flash_attn_2_available():
                    attn_output = flash_attn_func(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), 0.0, softmax_scale=None, causal=q_len != 1).view(bsz, 1, q_len, self.head_dim)
                else:
                    attn_output = gather_qkv(q, k, v, attention_mask)
                output[:, head:head + 1] = attn_output
            if self.is_search:
                config_list.append(config)
                with open(self.config_path, 'w') as json_file:
                    json.dump(config_list, json_file)
        else:
            output = flash_attn_func(query_states.transpose(1, 2), key_states.transpose(1, 2), value_states.transpose(1, 2), 0.0, softmax_scale=None, causal=q_len != 1).view(bsz, query_states.size(1), q_len, self.head_dim)
        attn_output = output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value

    return forward
def minference_kv_cache_cpu_forward():
    """Build a memory-saving attention ``forward`` that processes one head at a time.

    Q/K/V for each head are projected directly from slices of the projection weights
    (biases are not applied; NOTE(review): confirm the target models have none). When
    caching, new K/V are copied into CPU tensors and ``past_key_value.get``/``update``
    return and store the per-head cache. The output projection is written into
    ``hidden_states`` in place, and that tensor is returned.
    """
    def forward(
        self,
        hidden_states,
        attention_mask,
        position_ids,
        past_key_value,
        output_attentions,
        use_cache,
        **kwargs,
    ):
        global ROPE_TYPE
        self.init_minference_parameters()
        self.ne_inf = torch.finfo(hidden_states.dtype).min

        bsz, q_len, hidden_dim = hidden_states.size()
        kv_seq_len = q_len
        if use_cache and past_key_value is not None:
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

        if ROPE_TYPE is None:
            ROPE_TYPE = "seq_len" in inspect.signature(self.rotary_emb.forward).parameters
        if ROPE_TYPE:
            cos, sin = self.rotary_emb(hidden_states, seq_len=kv_seq_len)
        else:
            cos, sin = self.rotary_emb(hidden_states, position_ids)
        cache_kwargs = {"sin": sin, "cos": cos}

        attn_out = torch.empty_like(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
        act_num_heads = self.num_heads // self.num_key_value_groups
        if use_cache:
            k = torch.zeros(bsz, act_num_heads, q_len, self.head_dim).to(hidden_states.dtype).cpu()
            v = torch.zeros(bsz, act_num_heads, q_len, self.head_dim).to(hidden_states.dtype).cpu()

        # Per-head weight views, built once outside the head loop.
        if "q_proj" in self.__dict__["_modules"]:
            q_weight = self.q_proj.weight.view(self.num_heads, self.head_dim, hidden_dim)
            k_weight = self.k_proj.weight.view(act_num_heads, self.head_dim, hidden_dim)
            v_weight = self.v_proj.weight.view(act_num_heads, self.head_dim, hidden_dim)
        else:
            # Fused rows are [q (num_heads) | k (kv heads) | v (kv heads)]. Slicing by
            # row count also handles GQA, where the three parts differ in size.
            q_rows = self.num_heads * self.head_dim
            kv_rows = act_num_heads * self.head_dim
            qkv_weight = self.qkv_proj.weight
            q_weight = qkv_weight[:q_rows].view(self.num_heads, self.head_dim, hidden_dim)
            k_weight = qkv_weight[q_rows:q_rows + kv_rows].view(act_num_heads, self.head_dim, hidden_dim)
            v_weight = qkv_weight[q_rows + kv_rows:].view(act_num_heads, self.head_dim, hidden_dim)

        part_k, part_v = None, None
        for head in range(self.num_heads):
            kv_head = head // self.num_key_value_groups
            part_q = F.linear(hidden_states, q_weight[head]).unsqueeze(2)
            part_q = apply_rotary_pos_emb_single(part_q.transpose(1, 2), cos, sin, position_ids)  # (bsz, 1, q_len, head_dim)

            # K/V are shared by each group of query heads; compute them on the group's first head.
            if head % self.num_key_value_groups == 0:
                part_k = F.linear(hidden_states, k_weight[kv_head]).unsqueeze(2)
                part_v = F.linear(hidden_states, v_weight[kv_head]).unsqueeze(2).transpose(1, 2)
                part_k = apply_rotary_pos_emb_single(part_k.transpose(1, 2), cos, sin, position_ids)
                if use_cache and past_key_value is not None:
                    k[:, kv_head] = part_k.cpu()
                    v[:, kv_head] = part_v.cpu()
                    part_k, part_v = past_key_value.get(part_k, part_v, self.layer_idx, kv_head, cache_kwargs)

            if self.layer_idx >= self.starting_layer:
                part_o = self.gather_last_q_vertical_slash_topk_v4(part_q, part_k, part_v, head)
            else:
                # part_q/part_k/part_v are (bsz, 1, seq, head_dim). Transpose all three
                # to (bsz, seq, 1, head_dim) and view back as (bsz, q_len, head_dim);
                # previously only v was transposed and the output length was wrong.
                part_o = flash_attn_func(part_q.transpose(1, 2), part_k.transpose(1, 2), part_v.transpose(1, 2), 0.0, softmax_scale=None, causal=True).view(bsz, part_q.shape[2], self.head_dim)
            attn_out[:, :, head, :] = part_o

        if use_cache and past_key_value is not None:
            past_key_value.update(k, v, self.layer_idx, cache_kwargs)
        torch.matmul(attn_out.view(bsz, q_len, hidden_dim), self.o_proj.weight.T, out=hidden_states)
        torch.cuda.empty_cache()
        return (hidden_states, None, past_key_value)

    return forward
def minference_with_snapkv_forward():
    """Build an attention ``forward`` combining MInference prefill with SnapKV compression.

    On the step where all ``kv_seq_len`` keys are new (prefill), the returned function
    stores the output of ``self.kv_cluster.update_kv`` in the cache, while attention for
    that step still uses the full uncompressed keys/values. On later steps it appends
    to the cache and attends over what ``past_key_value.update`` returns. The logical
    sequence length is tracked in ``self.kv_seq_len``.
    """
    def forward(
        self,
        hidden_states,
        attention_mask,
        position_ids,
        past_key_value,
        output_attentions,
        use_cache,
        **kwargs,
    ):
        self.init_minference_parameters()
        self.ne_inf = torch.finfo(hidden_states.dtype).min
        init_snapkv(self)

        bsz, q_len, _ = hidden_states.size()

        # Only separate q/k/v projections are supported here (no fused qkv_proj path).
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )

            # Prefer the tracked logical length over the cache length; presumably the
            # compressed cache holds fewer entries than real positions. TODO confirm.
            if hasattr(self, "kv_seq_len"): #[SnapKV] add kv_seq_len
                if self.kv_seq_len != 0:
                    kv_seq_len += self.kv_seq_len
                else:
                    kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
            else:
                kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        global ROPE_TYPE
        if ROPE_TYPE is None:
            ROPE_TYPE = "seq_len" in inspect.signature(self.rotary_emb.forward).parameters
        if ROPE_TYPE:
            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        else:
            cos, sin = self.rotary_emb(value_states, position_ids)

        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            if key_states.shape[-2] == kv_seq_len: # [SnapKV] add kv_cluster
                # Prefill: cache only the compressed K/V; attention below uses the full tensors.
                self.kv_seq_len = kv_seq_len # [SnapKV] register kv_seq_len
                key_states_compress, value_states_compress = self.kv_cluster.update_kv(key_states, query_states, value_states, attention_mask, self.num_key_value_groups)
                past_key_value.update(key_states_compress, value_states_compress, self.layer_idx, cache_kwargs)
            else:
                # Decode: advance the logical length and attend over the cached K/V.
                self.kv_seq_len += q_len
                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        if self.layer_idx >= self.starting_layer:
            # MInference per-head sparse attention for all heads.
            assert query_states.size(1) == key_states.size(1) == value_states.size(1)
            output = torch.empty_like(query_states)
            for head in range(query_states.size(1)):
                q = query_states[:, head, :, :].unsqueeze(1)
                k = key_states[:, head, :, :].unsqueeze(1)
                v = value_states[:, head, :, :].unsqueeze(1)
                output[:, head:head + 1] = self.gather_last_q_vertical_slash_topk_v4(q, k, v, head)

            attn_output = output.transpose(1, 2).contiguous()
            attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
            attn_output = self.o_proj(attn_output)
            return attn_output, None, past_key_value

        else:
            # Layers before starting_layer: dense attention per head (flash if available).
            output = torch.empty_like(query_states)
            for head in range(query_states.size(1)):
                q = query_states[:, head, :, :].unsqueeze(1)
                k = key_states[:, head, :, :].unsqueeze(1)
                v = value_states[:, head, :, :].unsqueeze(1)
                if is_flash_attn_2_available():
                    attn_output = flash_attn_func(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1,2), 0.0, softmax_scale=None, causal=q_len != 1).view(bsz, 1, q.shape[2], self.head_dim)
                else:
                    attn_output = gather_qkv(q, k, v, attention_mask)
                output[:, head:head + 1] = attn_output
            attn_output = output.transpose(1, 2).contiguous()
            attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
            attn_output = self.o_proj(attn_output)

            return attn_output, None, past_key_value

    return forward
def gather_last_q_vertical_slash_topk_vllm(self, q, k, v, head_id):
    """Sparse prefill attention for one head inside vLLM.

    Args:
        q, k, v: single-head tensors; the dense path views its output as
            ``(bsz, 1, q_len, head_dim)``.
        head_id: key into ``self.best_pattern``, whose entry is
            ``(type, vertical_size, slash_size, score)``.

    A single query uses dense flash attention. Otherwise the kernel named by the
    head's pattern runs with that entry's vertical/slash sizes.
    """
    q_len = q.shape[2]
    bsz = q.shape[0]
    head_dim = q.size(-1)

    def vertical_and_slash_kernel(q, k, v, vertical_size, slash_size):
        vertical_size, slash_size = min(q_len, max(vertical_size, 30)), min(q_len, max(slash_size, 50))
        last_q = min(64, q_len)
        qk = torch.einsum('bhmk, bhnk -> bhmn', q[:, :, -last_q:, :], k)
        # LAST_Q_MASK lives on the default CUDA device. Move it to q's device so
        # tensor-parallel ranks on other GPUs do not hit a device mismatch.
        qk[:, :, :, -last_q:] = torch.where(LAST_Q_MASK[..., -last_q:, -last_q:].to(q.device), qk[:, :, :, -last_q:], -torch.inf)
        qk = torch.nn.functional.softmax(qk, dim=-1, dtype=torch.float32)
        vertical = qk.sum(-2, keepdim=True)
        vertical[..., :30] = torch.inf
        vertical_topk = torch.topk(vertical, vertical_size, -1).indices
        slash = sum_all_diagonal_matrix(qk)[..., :-last_q + 1]
        slash[..., -100:] = torch.inf
        slash = (q_len - 1) - torch.topk(slash, slash_size, -1).indices
        return vertical_slash_sparse_attention(q, k, v, vertical_topk, slash)

    def block_sparse_kernel(q, k, v, vertical_size=None, slash_size=None):
        topk = 100
        return block_sparse_attention(q, k, v, topk)

    def dense(q, k, v, vertical_size=None, slash_size=None):
        return flash_attn_func(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), 0.0, softmax_scale=None, causal=q_len != 1).view(bsz, 1, q_len, head_dim)

    ty, vertical_size, slash_size, _ = self.best_pattern[head_id]

    if q_len == 1:
        return dense(q, k, v)

    fc = {
        "stream_llm": streaming_forward,
        "vertical_and_slash": vertical_and_slash_kernel,
        "block_sparse": block_sparse_kernel,
    }[ty]
    return fc(q, k, v, vertical_size, slash_size)
def minference_vllm_forward(
    pattern_config
):
    """Build a vLLM flash-attention ``forward`` replacement that uses MInference for prefill.

    Args:
        pattern_config: indexable by layer index. Each item maps head index (possibly a
            string, as loaded from JSON) to ``(type, vertical_size, slash_size, score)``.

    Prefill without a populated block table runs per-head MInference attention.
    Prefix prefill and decoding use vLLM's ``PagedAttention`` kernels.
    """
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata[FlashAttentionMetadata],
        kv_scale: float,
        layer_idx: int,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention and PagedAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        self.best_pattern = {int(ii): jj for ii, jj in pattern_config[layer_idx].items()}

        def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
            """Repeat each KV head ``n_rep`` times along dim 1.

            ``(seq_len, num_kv_heads, head_dim)`` -> ``(seq_len, num_kv_heads * n_rep, head_dim)``,
            with query head ``h`` paired with KV head ``h // n_rep`` (GQA grouping).
            """
            if n_rep == 1:
                return hidden_states
            # Interleave ([kv0, kv0, kv1, kv1, ...]). The previous expand along a new
            # leading axis tiled the heads and paired head h with KV head h % num_kv_heads.
            return torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)

        def minference_prefill_func(
            q, k, v,
        ):
            # q, k, v: (seq_len, num_heads, head_size). All tokens are treated as one
            # sequence. NOTE(review): confirm batched prefill of several prompts never
            # reaches this path.
            if q.size(-2) != k.size(-2):
                k = repeat_kv(k, q.size(-2) // k.size(-2))
                v = repeat_kv(v, q.size(-2) // v.size(-2))

            output = torch.empty_like(q)
            for head in range(q.size(-2)):
                q_head = q[:, head, :].unsqueeze(1)
                k_head = k[:, head, :].unsqueeze(1)
                v_head = v[:, head, :].unsqueeze(1)

                # (1, seq_len, 1, head_size) -> (1, 1, seq_len, head_size)
                q_head = q_head[None, ...]
                k_head = k_head[None, ...]
                v_head = v_head[None, ...]

                q_head = q_head.transpose(1, 2)
                k_head = k_head.transpose(1, 2)
                v_head = v_head.transpose(1, 2)

                out = self.gather_last_q_vertical_slash_topk_vllm(q_head, k_head, v_head, head)

                out = out.transpose(1, 2).squeeze(0).contiguous()
                output[:, head:head+1, :] = out
            return output

        num_tokens, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if kv_cache is not None:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            PagedAttention.write_to_paged_cache(key, value, key_cache,
                                                value_cache,
                                                attn_metadata.slot_mapping,
                                                attn_metadata.kv_cache_dtype,
                                                kv_scale)

        num_prefill_tokens = attn_metadata.num_prefill_tokens
        num_decode_tokens = attn_metadata.num_decode_tokens
        assert key.shape[0] == num_prefill_tokens + num_decode_tokens
        assert value.shape[0] == num_prefill_tokens + num_decode_tokens

        output = torch.empty_like(query)
        # Query for decode. KV is not needed because it is already cached.
        decode_query = query[num_prefill_tokens:]
        # QKV for prefill.
        query = query[:num_prefill_tokens]
        key = key[:num_prefill_tokens]
        value = value[:num_prefill_tokens]

        assert query.shape[0] == num_prefill_tokens
        assert decode_query.shape[0] == num_decode_tokens

        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
            if kv_cache is None or prefill_meta.block_tables.numel() == 0:
                # Normal attention: block tables are empty, so q and k are the prompt
                # and have the same length. MInference replaces the dense
                # flash_attn_varlen_func(..., causal=True) call used upstream.
                out = minference_prefill_func(query, key, value)
                assert output[:num_prefill_tokens].shape == out.shape
                output[:num_prefill_tokens] = out
            else:
                # prefix-enabled attention
                # TODO(Hai) this triton kernel has regression issue (broke) to
                # deal with different data types between KV and FP8 KV cache,
                # to be addressed separately.
                output[:num_prefill_tokens] = PagedAttention.forward_prefix(
                    query,
                    key,
                    value,
                    key_cache,
                    value_cache,
                    prefill_meta.block_tables,
                    prefill_meta.subquery_start_loc,
                    prefill_meta.prompt_lens_tensor,
                    prefill_meta.context_lens,
                    prefill_meta.max_subquery_len,
                    self.alibi_slopes,
                )
        if decode_meta := attn_metadata.decode_metadata:
            # Decoding run.
            output[num_prefill_tokens:] = PagedAttention.forward_decode(
                decode_query,
                key_cache,
                value_cache,
                decode_meta.block_tables,
                decode_meta.context_lens,
                decode_meta.max_context_len,
                attn_metadata.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                kv_scale,
            )

        # Reshape the output tensor.
        return output.view(num_tokens, hidden_size)

    return forward