import os

import torch
from torch import nn

# xformers is optional; fall back to the pure-PyTorch implementation when it is
# not installed (or when explicitly disabled via the XFORMERS_DISABLED env var).
try:
    from xformers.ops import LowerTriangularMask, memory_efficient_attention, unbind

    XFORMERS_AVAILABLE = True
except ImportError:
    XFORMERS_AVAILABLE = False

XFORMERS_DISABLED = os.environ.get("XFORMERS_DISABLED", "false").lower() == "true"

class BasicSelfAttention(nn.Module):
def __init__(
self,
num_heads: int,
d_model: int,
qkv_bias: bool = False,
proj_bias: bool = True,
qk_norm: bool = True,
use_mup: bool = True,
attn_drop: float = 0.0,
) -> None:
super().__init__()
self.num_heads = num_heads
self.head_dim = d_model // num_heads
        # muP-style attention scaling (1/head_dim instead of 1/sqrt(head_dim));
        # the factor 8 makes it match 1/sqrt(head_dim) when head_dim == 64.
        self.scale = 8 / self.head_dim if use_mup else self.head_dim**-0.5
self.qkv = nn.Linear(d_model, d_model * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(d_model, d_model, bias=proj_bias)
self.qk_norm = qk_norm
if self.qk_norm:
            # qk normalization, https://arxiv.org/pdf/2302.05442
            # Note that LayerNorm runs in fp32, so q and k have to be cast back
            # to the activation dtype after normalization (see forward()).
            self.norm = nn.LayerNorm(self.head_dim, eps=1e-05)
def forward(self, x: torch.Tensor, causal: bool = False) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
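        # q, k, v: (B, num_heads, N, head_dim)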
if self.qk_norm:
q = self.norm(q)
k = self.norm(k)
# LN done in float32, cast back to bf16
q = q.to(dtype=v.dtype)
k = k.to(dtype=v.dtype)
q *= self.scale
attn = q @ k.transpose(-2, -1)
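        # attn: (B, num_heads, N, N) pre-softmax attention logits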
if causal:
mask_value = -torch.finfo(attn.dtype).max
i, j = attn.shape[-2:]
            mask = ~torch.tril(torch.ones(i, j, dtype=torch.bool, device=attn.device))
attn = attn.masked_fill(mask, mask_value)
attn = attn.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
return x
class BasicCrossAttention(nn.Module):
def __init__(
self,
num_heads: int,
d_model: int,
k_model: int,
qkv_bias: bool = False,
proj_bias: bool = True,
qk_norm: bool = True,
use_mup: bool = True,
attn_drop: float = 0.0,
) -> None:
super().__init__()
self.num_heads = num_heads
self.head_dim = d_model // num_heads
        # muP-style attention scaling (1/head_dim instead of 1/sqrt(head_dim));
        # the factor 8 makes it match 1/sqrt(head_dim) when head_dim == 64.
        self.scale = 8 / self.head_dim if use_mup else self.head_dim**-0.5
# self.qkv = nn.Linear(d_model, d_model * 3, bias=qkv_bias)
self.to_q = nn.Linear(d_model, d_model, bias=qkv_bias)
self.to_k = nn.Linear(d_model, d_model, bias=qkv_bias)
self.to_v = nn.Linear(d_model, d_model, bias=qkv_bias)
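        # NOTE: k_model is currently unused by these projections; the keys and
        # values are expected to already have d_model channels.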
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(d_model, d_model, bias=proj_bias)
self.qk_norm = qk_norm
if self.qk_norm:
            # qk normalization, https://arxiv.org/pdf/2302.05442
            # Note that LayerNorm runs in fp32, so q and k have to be cast back
            # to the activation dtype after normalization (see forward()).
            self.norm = nn.LayerNorm(self.head_dim, eps=1e-05)
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, causal: bool = False) -> torch.Tensor:
"""
q: (b s) t c
k: (b) t c
"""
B, N, C = q.shape
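        # k/v arrive with the smaller batch (see docstring); tile them to match
        # q's batch and clip them to q's sequence length.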
k = k.repeat(B // len(k), 1, 1)
v = v.repeat(B // len(v), 1, 1)
k = k[:, :q.shape[1]]
v = v[:, :q.shape[1]]
B, M, _ = k.shape
# qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
# q, k, v = qkv[0], qkv[1], qkv[2]
        # Project and move the heads in front of the sequence dimension,
        # (B, num_heads, seq, head_dim), so the per-head matmuls below line up
        # the same way as in BasicSelfAttention.
        q = self.to_q(q).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.to_k(k).reshape(B, M, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.to_v(v).reshape(B, M, self.num_heads, self.head_dim).transpose(1, 2)
if self.qk_norm:
q = self.norm(q)
k = self.norm(k)
# LN done in float32, cast back to bf16
q = q.to(dtype=v.dtype)
k = k.to(dtype=v.dtype)
q *= self.scale
attn = q @ k.transpose(-2, -1)
if causal:
mask_value = -torch.finfo(attn.dtype).max
i, j = attn.shape[-2:]
            mask = ~torch.tril(torch.ones(i, j, dtype=torch.bool, device=attn.device))
attn = attn.masked_fill(mask, mask_value)
attn = attn.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
return x
class MemoryEfficientAttention(BasicSelfAttention):
    # NOTE: xformers' memory_efficient_attention dispatches to Flash Attention 2
    # kernels when they are available.
def forward(self, x: torch.Tensor, causal: bool = False) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
q, k, v = unbind(qkv, 2)
if self.qk_norm:
q = self.norm(q)
k = self.norm(k)
# LN done in float32, cast back to bf16
q = q.to(dtype=v.dtype)
k = k.to(dtype=v.dtype)
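        # memory_efficient_attention expects (B, N, num_heads, head_dim) inputs,
        # which is exactly the layout produced by the reshape/unbind above.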
attn_bias = LowerTriangularMask() if causal else None
        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias, scale=self.scale)
x = x.reshape([B, N, C])
x = self.proj(x)
return x
if XFORMERS_DISABLED or not XFORMERS_AVAILABLE:
    SelfAttention = BasicSelfAttention
else:
    SelfAttention = MemoryEfficientAttention
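

# Minimal usage sketch (illustrative only; the shapes, batch sizes, and the use of
# the Basic* classes here are assumptions, not part of the module above). It shows
# the expected tensor layouts for self- and cross-attention. The xformers path
# typically requires a CUDA device, so the pure-PyTorch classes are exercised here.
if __name__ == "__main__":
    d_model, num_heads = 512, 8

    self_attn = BasicSelfAttention(num_heads=num_heads, d_model=d_model)
    x = torch.randn(2, 16, d_model)    # (batch, tokens, channels)
    y = self_attn(x, causal=True)      # -> (2, 16, 512)

    cross_attn = BasicCrossAttention(num_heads=num_heads, d_model=d_model, k_model=d_model)
    q = torch.randn(6, 16, d_model)    # queries with batch flattened as (b s) = 2 * 3
    ctx = torch.randn(2, 16, d_model)  # context with batch b = 2
    z = cross_attn(q, ctx, ctx)        # -> (6, 16, 512)
    print(y.shape, z.shape)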