from typing import Optional, Tuple

import torch
from torch import Tensor
from torch.nn import Linear, Module
from torch.nn import functional as F
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear


class MultiheadAttention(Module):
    """Batch-first multi-head attention with a separate `infer` path that
    supports an incremental key/value cache for autoregressive decoding."""

    __constants__ = ["batch_first"]
    bias_k: Optional[torch.Tensor]
    bias_v: Optional[torch.Tensor]

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        kdim=None,
        vdim=None,
        batch_first=False,
        linear1_cls=Linear,
        linear2_cls=Linear,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = False

        self.num_heads = num_heads
        self.dropout = dropout  # stored for API compatibility; not applied below
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"

        self.k_proj = Linear(self.kdim, embed_dim, **factory_kwargs)
        self.v_proj = Linear(self.vdim, embed_dim, **factory_kwargs)
        self.q_proj = Linear(embed_dim, embed_dim, **factory_kwargs)
        self.out_proj = NonDynamicallyQuantizableLinear(
            embed_dim, embed_dim, bias=bias, **factory_kwargs
        )
        self.add_zero_attn = add_zero_attn
        self.scaling = self.head_dim**-0.5

    def __setstate__(self, state):
        # Support loading old MultiheadAttention checkpoints generated by v1.1.0
        if "_qkv_same_embed_dim" not in state:
            state["_qkv_same_embed_dim"] = True
        super(MultiheadAttention, self).__setstate__(state)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[Tensor] = None,
        average_attn_weights: bool = True,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        # Inputs are expected batch-first: (B, T, C).
        B, T, C = query.size()
        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)
        q *= self.scaling

        # Split heads: (B, T, C) -> (B, nh, T, hs)
        k = k.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)
        q = q.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)
        v = v.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)

        attn_weights = q @ k.transpose(-2, -1)  # (B, nh, T, T)

        if attn_mask is not None:
            if torch.is_floating_point(attn_mask):
                # Additive float mask, broadcast over batch and heads.
                attn_weights += attn_mask.unsqueeze(0).unsqueeze(1)
            else:
                # Boolean mask: True marks positions that must not be attended to.
                attn_weights = attn_weights.masked_fill(attn_mask, float("-inf"))

        if key_padding_mask is not None:
            # Don't attend to padding symbols.
            attn_weights = attn_weights.view(B, self.num_heads, T, T)
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                float("-inf"),
            )

        attn_weights_float = F.softmax(attn_weights, dim=-1)
        attn = attn_weights_float @ v

        # Re-assemble all head outputs side by side: (B, nh, T, hs) -> (B, T, C)
        y = attn.transpose(1, 2).contiguous().view(B, T, C)
        y = self.out_proj(y)
        return y, attn_weights

    def infer(
        self,
        x: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[Tensor] = None,
        average_attn_weights: bool = True,
        past_kv: Optional[Tuple[Tensor, Tensor]] = None,
        use_cache: bool = False,
    ):
        # Self-attention with an optional key/value cache for incremental decoding.
        B, T, C = x.size()
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        q *= self.scaling
        # Split heads: (B, T, C) -> (B, nh, T, hs)
        k = k.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)
        q = q.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)
        v = v.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)

        if past_kv is not None:
            # Prepend cached keys/values along the time dimension.
            past_key, past_value = past_kv
            k = torch.cat((past_key, k), dim=-2)
            v = torch.cat((past_value, v), dim=-2)

        FULL_T = k.shape[-2]
        present = (k, v) if use_cache else None

        attn_weights = q @ k.transpose(-2, -1)  # (B, nh, T, FULL_T)

        # attn_mask is expected to be a boolean mask with True at disallowed
        # positions; slice out the rows for the new queries and the columns for
        # all cached plus new keys. key_padding_mask is not applied on this path.
        attn_weights = attn_weights.masked_fill(
            attn_mask[FULL_T - T : FULL_T, :FULL_T], float("-inf")
        )

        attn_weights_float = F.softmax(attn_weights, dim=-1)
        attn = attn_weights_float @ v

        # Re-assemble all head outputs side by side: (B, nh, T, hs) -> (B, T, C)
        y = attn.transpose(1, 2).contiguous().view(B, T, C)
        y = self.out_proj(y)
        return (y, present)
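

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the original module):
# exercises forward() with an additive float causal mask and infer() with an
# incremental key/value cache. The dimensions (B, T, C, H, max_T) and the
# boolean mask layout are assumptions inferred from how the methods index
# their inputs, not an interface documented by the source.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    B, T, C, H = 2, 8, 64, 4
    mha = MultiheadAttention(embed_dim=C, num_heads=H)

    x = torch.randn(B, T, C)  # batch-first input: (B, T, C)

    # forward(): a float mask is added to the logits, so use -inf above the diagonal.
    float_mask = torch.full((T, T), float("-inf")).triu(1)
    y, attn = mha(x, x, x, attn_mask=float_mask)
    print(y.shape, attn.shape)  # (B, T, C), (B, H, T, T)

    # infer(): boolean mask with True where attention is disallowed; the method
    # slices rows [FULL_T - T, FULL_T) and columns [0, FULL_T) out of it.
    max_T = 16  # assumed maximum decode length for this sketch
    bool_mask = torch.ones(max_T, max_T, dtype=torch.bool).triu(1)

    past_kv = None
    for t in range(4):  # feed one token at a time, reusing the cached K/V
        step = torch.randn(B, 1, C)
        out, past_kv = mha.infer(
            step, attn_mask=bool_mask, past_kv=past_kv, use_cache=True
        )
        print(t, out.shape, past_kv[0].shape)  # (B, 1, C), (B, H, t+1, head_dim)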