from typing import Optional, Tuple

import torch
from torch import Tensor
from torch.nn import Linear, Module
from torch.nn import functional as F
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear

class MultiheadAttention(Module):
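    """Batch-first multi-head attention with separate q/k/v projections.

    forward() computes attention over a full sequence; infer() performs
    incremental decoding, optionally reusing a key/value cache passed in
    through past_kv. Inputs and outputs use the (batch, seq_len, embed_dim)
    layout.
    """
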
__constants__ = ["batch_first"]
bias_k: Optional[torch.Tensor]
bias_v: Optional[torch.Tensor]

    def __init__(
self,
embed_dim,
num_heads,
dropout=0.0,
bias=True,
add_bias_kv=False,
add_zero_attn=False,
kdim=None,
vdim=None,
batch_first=False,
linear1_cls=Linear,
linear2_cls=Linear,
device=None,
dtype=None,
) -> None:
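        """kdim/vdim default to embed_dim when omitted. add_bias_kv,
        add_zero_attn, linear1_cls, linear2_cls and dropout are accepted for
        interface compatibility but are not used by forward()/infer()."""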
factory_kwargs = {"device": device, "dtype": dtype}
super(MultiheadAttention, self).__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
self._qkv_same_embed_dim = False
        self.num_heads = num_heads
        self.dropout = dropout  # stored for API compatibility; not applied below
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
        # Per-source projections: keys from kdim, values from vdim, queries from embed_dim.
        self.k_proj = Linear(self.kdim, embed_dim, **factory_kwargs)
        self.v_proj = Linear(self.vdim, embed_dim, **factory_kwargs)
        self.q_proj = Linear(embed_dim, embed_dim, **factory_kwargs)
self.out_proj = NonDynamicallyQuantizableLinear(
embed_dim, embed_dim, bias=bias, **factory_kwargs
)
        self.bias_k = self.bias_v = None  # add_bias_kv is not implemented here
        self.add_zero_attn = add_zero_attn
        self.scaling = self.head_dim**-0.5

def __setstate__(self, state):
# Support loading old MultiheadAttention checkpoints generated by v1.1.0
if "_qkv_same_embed_dim" not in state:
state["_qkv_same_embed_dim"] = True
super(MultiheadAttention, self).__setstate__(state)

    def forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
average_attn_weights: bool = True,
) -> Tuple[Tensor, Optional[Tensor]]:
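        """Full-sequence scaled dot-product attention over batch-first inputs.

        query, key and value are expected as (B, T, C). attn_mask may be a
        floating additive mask or a boolean mask (True = blocked);
        key_padding_mask marks padded positions per batch element.
        need_weights and average_attn_weights are accepted for API
        compatibility but ignored. Returns the attended output of shape
        (B, T, C) and the masked pre-softmax attention scores of shape
        (B, num_heads, T, T).
        """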
        # Batch-first layout: (B, T, C).
        B, T, C = query.size()
q = self.q_proj(query)
k = self.k_proj(key)
v = self.v_proj(value)
q *= self.scaling
k = k.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, nh, T, hs)
q = q.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, nh, T, hs)
v = v.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, nh, T, hs)
attn_weights = q @ k.transpose(-2, -1) # B, nh, T, T
        if attn_mask is not None:
            if torch.is_floating_point(attn_mask):
                # Additive float mask (e.g. -inf at disallowed positions),
                # broadcast over batch and heads.
                attn_weights += attn_mask.unsqueeze(0).unsqueeze(1)
            else:
                # Boolean mask: True marks positions that must not be attended to.
                attn_weights = attn_weights.masked_fill(attn_mask, float("-inf"))
if key_padding_mask is not None:
# don't attend to padding symbols
attn_weights = attn_weights.view(B, self.num_heads, T, T)
attn_weights = attn_weights.masked_fill(
key_padding_mask.unsqueeze(1)
.unsqueeze(2)
.to(torch.bool),
float("-inf"),
)
attn_weights_float = F.softmax(attn_weights, dim=-1)
attn = attn_weights_float @ v
y = attn.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
y = self.out_proj(y)
return y, attn_weights

    def infer(
        self,
        x: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[Tensor] = None,
        average_attn_weights: bool = True,
        past_kv: Optional[Tuple[Tensor, Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[Tensor, Optional[Tuple[Tensor, Tensor]]]:
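        """Attention step for incremental decoding with an optional KV cache.

        x holds the new tokens as (B, T, C). If past_kv is a (key, value) pair
        of shape (B, num_heads, T_past, head_dim), the new keys/values are
        appended to it. attn_mask must be provided; it is sliced so the new
        rows attend over all cached positions. Returns (output, present),
        where present is the updated (key, value) cache when use_cache is
        True and None otherwise. key_padding_mask, need_weights and
        average_attn_weights are ignored here.
        """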
        B, T, C = x.size()
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        q *= self.scaling
k = k.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, nh, T, hs)
q = q.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, nh, T, hs)
v = v.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, nh, T, hs)
if past_kv is not None:
past_key = past_kv[0]
past_value = past_kv[1]
k = torch.cat((past_key, k), dim=-2)
v = torch.cat((past_value, v), dim=-2)
FULL_T = k.shape[-2]
        present = (k, v) if use_cache else None
        # Attend over the full (cached + new) sequence; the causal mask is sliced
        # to the rows of the new tokens and the columns of all FULL_T positions.
        attn_weights = q @ k.transpose(-2, -1)  # (B, nh, T, FULL_T)
        attn_weights = attn_weights.masked_fill(
            attn_mask[FULL_T - T : FULL_T, :FULL_T], float("-inf")
        )
        # NOTE: key_padding_mask is accepted for API compatibility but is not
        # applied in this incremental-decoding path.
        attn_weights_float = F.softmax(attn_weights, dim=-1)
attn = attn_weights_float @ v
y = attn.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
y = self.out_proj(y)
return (y, present)
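

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the module above). It
# exercises forward() with a boolean causal mask and infer() with the KV
# cache; tensor shapes and the mask layout are inferred from the code above
# rather than from any documented API.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    B, T, C, H = 2, 4, 16, 4  # batch, sequence length, embed_dim, num_heads
    mha = MultiheadAttention(embed_dim=C, num_heads=H)

    x = torch.randn(B, T, C)
    # True above the diagonal -> future positions are blocked.
    causal = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)

    # Full-sequence attention: (B, T, C) output and (B, H, T, T) scores.
    y, attn = mha(x, x, x, attn_mask=causal)

    # Incremental decoding: prefill the cache, then feed one new token.
    _, kv = mha.infer(x, attn_mask=causal, use_cache=True)
    causal_full = torch.triu(torch.ones(T + 1, T + 1, dtype=torch.bool), diagonal=1)
    x_new = torch.randn(B, 1, C)
    y1, kv = mha.infer(x_new, attn_mask=causal_full, past_kv=kv, use_cache=True)
    print(y.shape, attn.shape, y1.shape, kv[0].shape)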