import math

import hydra
import torch
import torch.nn as nn
from einops import rearrange

from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
from flash_attn.flash_blocksparse_attn_interface import (
    convert_blockmask,
    flash_blocksparse_attn_func,
)

class FlashBlocksparseAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.

    Arguments
    ---------
        softmax_temp: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(
        self,
        sparsity_config,
        softmax_temp=None,
        attention_dropout=0.0,
        max_seq_length=2048,
        device=None,
        dtype=None,
    ):
        super().__init__()
        self.sparsity_config = hydra.utils.instantiate(sparsity_config)
        self.softmax_temp = softmax_temp
        self.dropout_p = attention_dropout

        # Initialize the sparse layout (with max_seq_length rounded up to a
        # multiple of 256) and register it as a buffer so it moves with the module.
        max_seq_length = ((max_seq_length + 256 - 1) // 256) * 256
        layout = self.sparsity_config.make_layout(max_seq_length)
        self.register_buffer("layout", layout)
        blockmask_converted = convert_blockmask(self.layout, causal=False)
        self.register_buffer("blockmask_converted", blockmask_converted)
        # logger.info(f'Attention class {self.__class__}: saving={self.layout.float().mean()}')
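
    # `make_layout` is expected to return a 2D 0/1 block mask laid out over
    # (16 x 256) tiles of the attention matrix; `forward` slices it down to the
    # rounded active sequence length before handing it to the kernel.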
    def forward(
        self,
        qkv,
        attn_mask=None,
        key_padding_mask=None,
        causal=False,
        cu_seqlens=None,
        max_s=None,
        need_weights=False,
        convert_mask=True,
    ):
        """Implements the multihead softmax attention.

        Arguments
        ---------
            qkv: The tensor containing the query, key, and value.
                 (B, S, 3, H, D) if key_padding_mask is None.
            attn_mask: An implementation of BaseMask that encodes where each
                       query can attend.
            key_padding_mask: An implementation of BaseMask that encodes how
                              many queries each sequence in the batch consists of.
        """
        assert not need_weights
        assert attn_mask is None
        assert qkv.dtype == torch.float16
        assert qkv.is_cuda

        if cu_seqlens is None:
            batch_size = qkv.shape[0]
            seqlen = qkv.shape[1]
            # Slice the stored layout down to the current (rounded) sequence length.
            seqlen_rounded = ((seqlen + 256 - 1) // 256) * 256
            assert (
                seqlen_rounded // 16 <= self.layout.shape[0]
                and seqlen_rounded // 256 <= self.layout.shape[1]
            )
            blockmask = self.layout[: seqlen_rounded // 16, : seqlen_rounded // 256]
            if key_padding_mask is None:
                qkv = rearrange(qkv, "b s ... -> (b s) ...")
                max_s = seqlen
                cu_seqlens = torch.arange(
                    0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, device=qkv.device
                )
                output = flash_blocksparse_attn_func(
                    qkv,
                    cu_seqlens,
                    blockmask,
                    self.dropout_p if self.training else 0.0,
                    max_s,
                    softmax_scale=self.softmax_temp,
                    causal=causal,
                )
                output = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
            else:
                # Drop padded tokens, run the varlen kernel, then re-pad the output.
                key_padding_mask_bool = key_padding_mask.bool_matrix
                nheads = qkv.shape[-2]
                x = rearrange(qkv, "b s three h d -> b s (three h d)")
                x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask_bool)
                x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads)
                output_unpad = flash_blocksparse_attn_func(
                    x_unpad,
                    cu_seqlens,
                    blockmask,
                    self.dropout_p if self.training else 0.0,
                    max_s,
                    softmax_scale=self.softmax_temp,
                    causal=causal,
                )
                output = rearrange(
                    pad_input(
                        rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, batch_size, seqlen
                    ),
                    "b s (h d) -> b s h d",
                    h=nheads,
                )
        else:
            assert max_s is not None
            seqlen = max_s
            # Slice the stored layout down to the current (rounded) sequence length.
            seqlen_rounded = ((seqlen + 256 - 1) // 256) * 256
            assert (
                seqlen_rounded // 16 <= self.layout.shape[0]
                and seqlen_rounded // 256 <= self.layout.shape[1]
            )
            blockmask = self.layout[: seqlen_rounded // 16, : seqlen_rounded // 256]
            if convert_mask:
                output = flash_blocksparse_attn_func(
                    qkv,
                    cu_seqlens,
                    blockmask,
                    self.dropout_p if self.training else 0.0,
                    max_s,
                    softmax_scale=self.softmax_temp,
                    causal=causal,
                )
            else:
                # Reuse the full-length blockmask pre-converted in __init__.
                output = flash_blocksparse_attn_func(
                    qkv,
                    cu_seqlens,
                    self.blockmask_converted,
                    self.dropout_p if self.training else 0.0,
                    max_s,
                    softmax_scale=self.softmax_temp,
                    causal=causal,
                    convert_mask=False,
                )

        return output, None
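
# ---------------------------------------------------------------------------
# Illustrative sketch: how the packed, unpadded inputs that
# `FlashBlocksparseAttention.forward` accepts via `cu_seqlens` can be built.
# The helper name `_pack_qkv_varlen` and its arguments are hypothetical; it
# assumes `q`, `k`, `v` of shape (total_tokens, nheads, head_dim) concatenated
# across sequences, plus the per-sequence lengths.
def _pack_qkv_varlen(q, k, v, seqlens):
    """Stack q/k/v into (total_tokens, 3, nheads, head_dim) and build int32 cu_seqlens."""
    qkv = torch.stack([q, k, v], dim=1)
    # Cumulative sequence boundaries: [0, len_0, len_0 + len_1, ...].
    offsets = [0]
    for length in seqlens:
        offsets.append(offsets[-1] + length)
    cu_seqlens = torch.tensor(offsets, dtype=torch.int32, device=qkv.device)
    return qkv, cu_seqlens, max(seqlens)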

class FlashBlocksparseMHA(nn.Module):
    def __init__(
        self,
        embed_dim,
        num_heads,
        sparsity_config,
        bias=True,
        batch_first=True,
        attention_dropout=0.0,
        causal=False,
        max_seq_length=2048,
        device=None,
        dtype=None,
        **kwargs,
    ) -> None:
        assert batch_first
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.causal = causal

        self.num_heads = num_heads
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads
        assert self.head_dim in [16, 32, 64], "Only support head_dim == 16, 32, or 64"

        self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
        self.inner_attn = FlashBlocksparseAttention(
            sparsity_config,
            attention_dropout=attention_dropout,
            max_seq_length=max_seq_length,
            **factory_kwargs,
        )
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)

    def forward(
        self, x, x_ignored_, x_ignored_1_, attn_mask=None, key_padding_mask=None, need_weights=False
    ):
        # Project to packed qkv, split into heads, run blocksparse attention, project back.
        qkv = self.Wqkv(x)
        qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads)
        context, attn_weights = self.inner_attn(
            qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=self.causal
        )
        return self.out_proj(rearrange(context, "b s h d -> b s (h d)")), attn_weights
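
# ---------------------------------------------------------------------------
# Usage sketch. `sparsity_config` is resolved with `hydra.utils.instantiate`,
# so it is typically a Hydra config (`_target_` plus constructor kwargs) naming
# a class that exposes `make_layout(seq_len)`. `DenseLayoutConfig` below is a
# hypothetical stand-in that keeps every (16, 256) block, matching the layout
# shape that `FlashBlocksparseAttention.forward` slices.
class DenseLayoutConfig:
    """Hypothetical layout provider: mark every (16, 256) block as kept."""

    def make_layout(self, seq_len):
        return torch.ones(seq_len // 16, seq_len // 256, dtype=torch.int64)


if __name__ == "__main__":
    # Minimal smoke test, assuming a CUDA device and a working flash-attn build;
    # the blocksparse kernel only supports fp16 inputs, hence .half().
    if torch.cuda.is_available():
        mha = (
            FlashBlocksparseMHA(
                embed_dim=512,
                num_heads=8,
                sparsity_config={"_target_": "__main__.DenseLayoutConfig"},
                max_seq_length=512,
            )
            .cuda()
            .half()
        )
        x = torch.randn(2, 512, 512, device="cuda", dtype=torch.float16)
        out, _ = mha(x, None, None)
        print(out.shape)  # expected: torch.Size([2, 512, 512])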