import torch
import torch.nn as nn

from model.sublayers import (
    PositionalEncoding,
    Embedding)

from model.encoder import Encoder
from model.decoder import Decoder


class Transformer(nn.Module):
    def __init__(self,
                 dk,
                 dv,
                 h,
                 src_vocab_size,
                 target_vocab_size,
                 num_encoders,
                 num_decoders,
                 src_pad_idx,
                 target_pad_idx,
                 dim_multiplier=4,
                 pdropout=0.1,
                 device="cpu"):
        super().__init__()

        # Model width follows the original Transformer convention:
        # dmodel = h * dk (i.e. dk = dv = dmodel / h).
        dmodel = dk * h

        # Source side: token embeddings + positional encoding.
        # NOTE: max_seq_length is set to the vocab size here, which only acts
        # as a generous upper bound on the supported sequence length.
        self.src_embeddings = Embedding(src_vocab_size, dmodel)
        self.src_positional_encoding = PositionalEncoding(
            dmodel,
            max_seq_length=src_vocab_size,
            pdropout=pdropout)
        self.encoder = Encoder(
            dk,
            dv,
            h,
            num_encoders,
            dim_multiplier=dim_multiplier,
            pdropout=pdropout)

        # Target side mirrors the source side.
        self.target_embeddings = Embedding(target_vocab_size, dmodel)
        self.target_positional_encoding = PositionalEncoding(
            dmodel,
            max_seq_length=target_vocab_size,
            pdropout=pdropout)
        self.decoder = Decoder(
            dk,
            dv,
            h,
            num_decoders,
            dim_multiplier=dim_multiplier,
            pdropout=pdropout)

        # Final projection from dmodel to target vocabulary logits.
        self.linear = nn.Linear(dmodel, target_vocab_size)

        self.device = device
        self.src_pad_idx = src_pad_idx
        self.target_pad_idx = target_pad_idx
        self.init_params()

    def init_params(self):
        """
        Xavier initialization has a tremendous impact! I didn't expect that the
        model's performance, even with normalization layers, would depend so
        strongly on the choice of weight initialization.
        """
        for name, p in self.named_parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def make_src_mask(self, src):
        """
        Build the padding mask for the source batch.

        Args:
            src: raw token-id sequences with padding, shape (batch_size, seq_length)

        Returns:
            src_mask: bool mask per sequence, shape (batch_size, 1, 1, seq_length);
                positions equal to self.src_pad_idx are False and should not be attended to
        """
        batch_size = src.shape[0]
        src_mask = (src != self.src_pad_idx).view(batch_size, 1, 1, -1)
        return src_mask
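
    # A shape sketch (illustrative, not from the original code): for src of
    # shape (2, 5), make_src_mask returns a bool tensor of shape (2, 1, 1, 5)
    # that is False at padded positions. The singleton head and query
    # dimensions broadcast inside attention, where masked scores are
    # presumably filled with a large negative value before the softmax.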

    def make_target_mask(self, target):
        """
        Build the combined padding + causal (look-ahead) mask for the target batch.

        Args:
            target: raw token-id sequences with padding, shape (batch_size, seq_length)

        Returns:
            target_mask: bool mask per sequence, shape (batch_size, 1, seq_length, seq_length)
        """
        seq_length = target.shape[1]
        batch_size = target.shape[0]

        # Padding mask: (batch_size, 1, 1, seq_length)
        target_mask = (target != self.target_pad_idx).view(batch_size, 1, 1, -1)

        # Causal mask: lower-triangular (seq_length, seq_length), so position i
        # can only attend to positions <= i.
        trg_sub_mask = torch.tril(
            torch.ones((seq_length, seq_length), device=self.device)).bool()

        # Broadcasting the AND yields (batch_size, 1, seq_length, seq_length).
        target_mask = target_mask & trg_sub_mask
        return target_mask
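
    # A toy example (illustrative): for a single target sequence
    # [BOS, w1, w2, PAD] the padding mask keeps the first three positions,
    # while the lower-triangular mask lets step t see only steps <= t; the
    # elementwise AND therefore allows each step to attend only to earlier,
    # non-padded positions.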

    def forward(self,
                src_token_ids_batch,
                target_token_ids_batch):
        src_mask = self.make_src_mask(src_token_ids_batch)
        target_mask = self.make_target_mask(target_token_ids_batch)

        # Embed tokens and add positional information.
        src_representations = self.src_embeddings(src_token_ids_batch)
        src_representations = self.src_positional_encoding(src_representations)

        target_representations = self.target_embeddings(target_token_ids_batch)
        target_representations = self.target_positional_encoding(target_representations)

        # Encode the source, then decode the target against it.
        encoded_src = self.encoder(src_representations, src_mask)
        decoded_output = self.decoder(
            target_representations,
            encoded_src,
            target_mask,
            src_mask)

        # Raw (unnormalized) logits over the target vocabulary, shape
        # (batch_size, target_seq_length, target_vocab_size); no softmax is
        # applied here.
        out = self.linear(decoded_output)
        return out


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


if __name__ == "__main__":
    # The following hyperparameters are for the Multi30K dataset.
    dk = 32
    dv = 32
    h = 8
    src_vocab_size = 7983
    target_vocab_size = 5979
    src_pad_idx = 2
    target_pad_idx = 2
    num_encoders = 3
    num_decoders = 3
    dim_multiplier = 4
    pdropout = 0.1

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Transformer(
        dk,
        dv,
        h,
        src_vocab_size,
        target_vocab_size,
        num_encoders,
        num_decoders,
        src_pad_idx,
        target_pad_idx,
        dim_multiplier,
        pdropout,
        device=device)
    if torch.cuda.is_available():
        model.cuda()
    print(model)
    print(f'The model has {count_parameters(model):,} trainable parameters')
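
    # Hedged sanity check (not part of the original script): run a forward
    # pass on random token ids to confirm the output shape. This assumes the
    # Encoder/Decoder/Embedding classes under model/ accept the call
    # signatures used in Transformer.forward above.
    batch_size, src_len, target_len = 2, 7, 6
    src_batch = torch.randint(0, src_vocab_size, (batch_size, src_len), device=device)
    target_batch = torch.randint(0, target_vocab_size, (batch_size, target_len), device=device)
    with torch.no_grad():
        logits = model(src_batch, target_batch)
    # Expected shape: (batch_size, target_len, target_vocab_size)
    print(f'Output logits shape: {tuple(logits.shape)}')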