import math

import torch
import torch.nn as nn
from torch import Tensor


class MultiHeadedAttention(nn.Module):
    """
    Multi-Head Attention module from "Attention is All You Need".

    Implementation modified from OpenNMT-py.
    https://github.com/OpenNMT/OpenNMT-py
    """
    def __init__(self, num_heads: int, size: int, dropout: float = 0.1):
        """
        Create a multi-headed attention layer.

        :param num_heads: the number of heads
        :param size: model size (must be divisible by num_heads)
        :param dropout: probability of dropping a unit
        """
        super().__init__()

        assert size % num_heads == 0

        self.head_size = head_size = size // num_heads
        self.model_size = size
        self.num_heads = num_heads

        self.k_layer = nn.Linear(size, num_heads * head_size)
        self.v_layer = nn.Linear(size, num_heads * head_size)
        self.q_layer = nn.Linear(size, num_heads * head_size)

        self.output_layer = nn.Linear(size, size)
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)
    def forward(self, k: Tensor, v: Tensor, q: Tensor, mask: Tensor = None):
        """
        Computes multi-headed attention.

        :param k: keys   [B, M, D] with M being the sentence length.
        :param v: values [B, M, D]
        :param q: query  [B, M, D]
        :param mask: optional mask [B, 1, M] or [B, M, M]
        :return: attention output [B, M, D]
        """
        batch_size = k.size(0)
        num_heads = self.num_heads

        # project the queries (q), keys (k) and values (v)
        k = self.k_layer(k)
        v = self.v_layer(v)
        q = self.q_layer(q)

        # reshape q, k, v to [batch_size, num_heads, seq_len, head_size]
        k = k.view(batch_size, -1, num_heads, self.head_size).transpose(1, 2)
        v = v.view(batch_size, -1, num_heads, self.head_size).transpose(1, 2)
        q = q.view(batch_size, -1, num_heads, self.head_size).transpose(1, 2)

        # scale the queries
        q = q / math.sqrt(self.head_size)

        # compute attention scores: [batch_size, num_heads, query_len, key_len]
        scores = torch.matmul(q, k.transpose(2, 3))

        # apply the mask (if given): positions where the mask is False get -inf
        if mask is not None:
            scores = scores.masked_fill(~mask.unsqueeze(1), float('-inf'))

        # turn scores into attention probabilities and apply dropout
        attention = self.softmax(scores)
        attention = self.dropout(attention)

        # compute context vectors (weighted sum of values) and re-merge heads
        context = torch.matmul(attention, v)
        context = context.transpose(1, 2).contiguous().view(
            batch_size, -1, num_heads * self.head_size)

        output = self.output_layer(context)

        return output
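
# Usage sketch (illustrative; not part of the original module). Shapes follow
# the docstring above; the hyperparameters are arbitrary example values:
#
#   mha = MultiHeadedAttention(num_heads=8, size=512)
#   x = torch.rand(2, 7, 512)                      # [B, M, D]
#   mask = torch.ones(2, 1, 7, dtype=torch.bool)   # [B, 1, M]
#   out = mha(k=x, v=x, q=x, mask=mask)            # -> [2, 7, 512]
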
class PositionwiseFeedForward(nn.Module):
    """
    Position-wise feed-forward layer.
    Projects to ff_size and then back down to input_size.
    """

    def __init__(self, input_size, ff_size, dropout=0.1):
        """
        Initializes the position-wise feed-forward layer.

        :param input_size: dimensionality of the input
        :param ff_size: dimensionality of the intermediate representation
        :param dropout: probability of dropping a unit
        """
        super().__init__()
        self.layer_norm = nn.LayerNorm(input_size, eps=1e-6)
        self.pwff_layer = nn.Sequential(
            nn.Linear(input_size, ff_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(ff_size, input_size),
            nn.Dropout(dropout),
        )
    def forward(self, x):
        # pre-layer-norm, feed-forward, and residual connection
        x_norm = self.layer_norm(x)
        return self.pwff_layer(x_norm) + x
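
# Usage sketch (illustrative; not part of the original module):
#
#   ff = PositionwiseFeedForward(input_size=512, ff_size=2048)
#   y = ff(torch.rand(2, 7, 512))   # output keeps the input shape [2, 7, 512]
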
class PositionalEncoding(nn.Module):
    """
    Pre-compute position encodings (PE).
    In the forward pass, this adds the position encodings to the
    input for as many time steps as necessary.

    Implementation based on OpenNMT-py.
    https://github.com/OpenNMT/OpenNMT-py
    """

    def __init__(self, size: int = 0, max_len: int = 5000):
        """
        Positional encoding with maximum length max_len.

        :param size: model dimensionality (must be even)
        :param max_len: maximum sequence length to pre-compute encodings for
        """
        if size % 2 != 0:
            raise ValueError("Cannot use sin/cos positional encoding with "
                             "odd dim (got dim={:d})".format(size))
        pe = torch.zeros(max_len, size)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp((torch.arange(0, size, 2, dtype=torch.float) *
                              -(math.log(10000.0) / size)))
        pe[:, 0::2] = torch.sin(position.float() * div_term)
        pe[:, 1::2] = torch.cos(position.float() * div_term)
        pe = pe.unsqueeze(0)  # shape: [1, max_len, size]
        super().__init__()
        self.register_buffer('pe', pe)
        self.dim = size
    def forward(self, emb):
        """Embed inputs.

        Args:
            emb (FloatTensor): Sequence of word vectors
                ``(batch_size, seq_len, self.dim)``
        """
        # add the pre-computed encodings for the first emb.size(1) positions
        return emb + self.pe[:, :emb.size(1)]
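
# The pre-computed buffer corresponds to the sinusoidal encoding from
# "Attention is All You Need":
#
#   PE(pos, 2i)   = sin(pos / 10000^(2i / size))
#   PE(pos, 2i+1) = cos(pos / 10000^(2i / size))
#
# Usage sketch (illustrative; not part of the original module):
#
#   pos_enc = PositionalEncoding(size=512)
#   out = pos_enc(torch.rand(2, 7, 512))   # adds pe[:, :7] to the embeddings
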
class TransformerEncoderLayer(nn.Module):
    """
    One Transformer encoder layer has a multi-head attention layer plus
    a position-wise feed-forward layer.
    """

    def __init__(self,
                 size: int = 0,
                 ff_size: int = 0,
                 num_heads: int = 0,
                 dropout: float = 0.1):
        """
        A single Transformer encoder layer.

        :param size: model dimensionality
        :param ff_size: size of the feed-forward intermediate layer
        :param num_heads: number of attention heads
        :param dropout: dropout probability
        """
        super().__init__()

        self.layer_norm = nn.LayerNorm(size, eps=1e-6)
        self.src_src_att = MultiHeadedAttention(num_heads, size,
                                                dropout=dropout)
        self.feed_forward = PositionwiseFeedForward(size, ff_size=ff_size,
                                                    dropout=dropout)
        self.dropout = nn.Dropout(dropout)
        self.size = size
    def forward(self, x: Tensor, mask: Tensor) -> Tensor:
        """
        Forward pass for a single Transformer encoder layer.
        First applies layer norm, then self-attention, then dropout with a
        residual connection (adding the input to the result), and finally a
        position-wise feed-forward layer.

        :param x: layer input
        :param mask: input mask
        :return: output tensor
        """
        x_norm = self.layer_norm(x)
        h = self.src_src_att(x_norm, x_norm, x_norm, mask)
        h = self.dropout(h) + x
        o = self.feed_forward(h)
        return o
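
# Usage sketch (illustrative; not part of the original module; hyperparameters
# are arbitrary example values):
#
#   enc_layer = TransformerEncoderLayer(size=512, ff_size=2048, num_heads=8)
#   x = PositionalEncoding(size=512)(torch.rand(2, 7, 512))   # [B, M, D]
#   src_mask = torch.ones(2, 1, 7, dtype=torch.bool)          # [B, 1, M]
#   out = enc_layer(x, src_mask)                              # -> [2, 7, 512]
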
class TransformerDecoderLayer(nn.Module):
    """
    Transformer decoder layer.

    Consists of self-attention, source-attention, and feed-forward.
    """

    def __init__(self,
                 size: int = 0,
                 ff_size: int = 0,
                 num_heads: int = 0,
                 dropout: float = 0.1):
        """
        Represents a single Transformer decoder layer.

        It attends to the source representation and the previous decoder
        states.

        :param size: model dimensionality
        :param ff_size: size of the feed-forward intermediate layer
        :param num_heads: number of heads
        :param dropout: dropout to apply to input
        """
        super().__init__()
        self.size = size

        self.trg_trg_att = MultiHeadedAttention(num_heads, size,
                                                dropout=dropout)
        self.src_trg_att = MultiHeadedAttention(num_heads, size,
                                                dropout=dropout)

        self.feed_forward = PositionwiseFeedForward(size, ff_size=ff_size,
                                                    dropout=dropout)

        self.x_layer_norm = nn.LayerNorm(size, eps=1e-6)
        self.dec_layer_norm = nn.LayerNorm(size, eps=1e-6)

        self.dropout = nn.Dropout(dropout)
    def forward(self,
                x: Tensor = None,
                memory: Tensor = None,
                src_mask: Tensor = None,
                trg_mask: Tensor = None) -> Tensor:
        """
        Forward pass of a single Transformer decoder layer.

        :param x: inputs
        :param memory: source representations
        :param src_mask: source mask
        :param trg_mask: target mask (so as to not condition on future steps)
        :return: output tensor
        """
        # decoder self-attention with pre-layer-norm and residual connection
        x_norm = self.x_layer_norm(x)
        h1 = self.trg_trg_att(x_norm, x_norm, x_norm, mask=trg_mask)
        h1 = self.dropout(h1) + x

        # source-target attention over the encoder memory
        h1_norm = self.dec_layer_norm(h1)
        h2 = self.src_trg_att(memory, memory, h1_norm, mask=src_mask)

        # residual connection around the source attention, then the
        # position-wise feed-forward sub-layer
        o = self.feed_forward(self.dropout(h2) + h1)

        return o
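
# Usage sketch (illustrative; not part of the original module). The target mask
# here is a lower-triangular "subsequent" mask so that position t cannot attend
# to positions > t; a padding mask could be combined with it as well.
# Hyperparameters are arbitrary example values:
#
#   dec_layer = TransformerDecoderLayer(size=512, ff_size=2048, num_heads=8)
#   trg = torch.rand(2, 5, 512)                         # [B, T, D]
#   memory = torch.rand(2, 7, 512)                      # encoder output [B, M, D]
#   src_mask = torch.ones(2, 1, 7, dtype=torch.bool)    # [B, 1, M]
#   trg_mask = torch.tril(torch.ones(5, 5, dtype=torch.bool)).unsqueeze(0)
#   out = dec_layer(x=trg, memory=memory, src_mask=src_mask, trg_mask=trg_mask)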