from typing import Optional, Tuple, List

import math

import torch
from torch import Tensor
from torch.nn import Linear, Module
from torch.nn import functional as F
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear


class MultiheadAttention(Module):
    __constants__ = ["batch_first"]
    bias_k: Optional[torch.Tensor]
    bias_v: Optional[torch.Tensor]

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        kdim=None,
        vdim=None,
        batch_first=False,
        linear1_cls=Linear,
        linear2_cls=Linear,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = False
        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"

        # Separate projections for query, key and value. The query is projected
        # from embed_dim; keys and values come from kdim/vdim respectively
        # (both default to embed_dim).
        self.q_proj = Linear(embed_dim, embed_dim, **factory_kwargs)
        self.k_proj = Linear(self.kdim, embed_dim, **factory_kwargs)
        self.v_proj = Linear(self.vdim, embed_dim, **factory_kwargs)
        self.out_proj = NonDynamicallyQuantizableLinear(
            embed_dim, embed_dim, bias=bias, **factory_kwargs
        )
        self.add_zero_attn = add_zero_attn
        self.scaling = self.head_dim**-0.5

    def __setstate__(self, state):
        # Support loading old MultiheadAttention checkpoints generated by v1.1.0
        if "_qkv_same_embed_dim" not in state:
            state["_qkv_same_embed_dim"] = True
        super(MultiheadAttention, self).__setstate__(state)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[Tensor] = None,
        average_attn_weights: bool = True,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        # Inputs are batch-first: (B, T, C).
        B, T, C = query.size()
        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)
        q *= self.scaling

        # Split the channel dimension into heads: (B, T, C) -> (B, nh, T, hs).
        k = k.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)
        q = q.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)
        v = v.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)

        attn_weights = q @ k.transpose(-2, -1)  # (B, nh, T, T)

        if attn_mask is not None:
            if torch.is_floating_point(attn_mask):
                # Additive float mask (e.g. -inf at disallowed positions),
                # broadcast over batch and heads.
                attn_weights += attn_mask.unsqueeze(0).unsqueeze(1)
            else:
                # Boolean mask: True marks positions that must not be attended to.
                attn_weights = attn_weights.masked_fill(attn_mask, float("-inf"))

        if key_padding_mask is not None:
            # Don't attend to padding symbols; key_padding_mask is (B, T) with
            # True at padded positions.
            attn_weights = attn_weights.view(B, self.num_heads, T, T)
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                float("-inf"),
            )

        attn_weights_float = F.softmax(attn_weights, dim=-1)
        attn = attn_weights_float @ v
        y = attn.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
        y = self.out_proj(y)
        # NOTE: the returned weights are the masked, pre-softmax scores.
        return y, attn_weights

    def infer(
        self,
        x: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[Tensor] = None,
        average_attn_weights: bool = True,
        past_kv=None,
        use_cache=False,
    ):
        # Incremental self-attention decoding step; x is batch-first (B, T, C).
        B, T, C = x.size()
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        q *= self.scaling

        k = k.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)  # (B, nh, T, hs)

        if past_kv is not None:
            # Prepend cached keys/values along the time dimension.
            past_key, past_value = past_kv
            k = torch.cat((past_key, k), dim=-2)
            v = torch.cat((past_value, v), dim=-2)
        FULL_T = k.shape[-2]  # total sequence length including the cache

        present = (k, v) if use_cache else None

        attn_weights = q @ k.transpose(-2, -1)  # (B, nh, T, FULL_T)
        # attn_mask is expected to be a boolean causal mask over the full
        # (maximum) sequence; slice out the rows for the new positions.
        attn_weights = attn_weights.masked_fill(
            attn_mask[FULL_T - T : FULL_T, :FULL_T], float("-inf")
        )
        # NOTE: key_padding_mask is accepted but not applied here.

        attn_weights_float = F.softmax(attn_weights, dim=-1)
        attn = attn_weights_float @ v
        y = attn.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
        y = self.out_proj(y)
        return (y, present)
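

# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module): one way to drive
# MultiheadAttention in full-sequence mode and in incremental decoding with
# the KV cache returned by infer(). The shapes, the causal-mask construction,
# and the step loop below are illustrative assumptions, not an API documented
# by this file.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    B, T, C, H = 2, 8, 64, 4
    mha = MultiheadAttention(embed_dim=C, num_heads=H, batch_first=True)

    x = torch.randn(B, T, C)
    # Boolean causal mask: True marks positions that must NOT be attended to.
    causal = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
    y, attn = mha(x, x, x, attn_mask=causal)
    print(y.shape)  # torch.Size([2, 8, 64])

    # Incremental decoding: feed one token at a time and reuse the cache.
    full_mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
    past_kv = None
    for t in range(T):
        step = x[:, t : t + 1, :]  # (B, 1, C)
        y_t, past_kv = mha.infer(
            step, attn_mask=full_mask, past_kv=past_kv, use_cache=True
        )
        # y_t: (B, 1, C); past_kv holds keys/values of shape (B, H, t + 1, C // H)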