import torch
import torch.nn as nn
from model.sublayers import (
MultiHeadAttention,
PositionalEncoding,
PositionwiseFeedForward,
Embedding)
from model.encoder import Encoder
from model.decoder import Decoder
class Transformer(nn.Module):
def __init__(self,
dk,
dv,
h,
src_vocab_size,
target_vocab_size,
num_encoders,
num_decoders,
src_pad_idx,
target_pad_idx,
dim_multiplier = 4,
pdropout=0.1,
device = "cpu"
):
super().__init__()
# Reference page 5 chapter 3.2.2 Multi-head attention
dmodel = dk*h
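        # e.g. with the configuration used in __main__ below (dk = dv = 32, h = 8)
        # this gives dmodel = 256, i.e. each of the h heads operates on a
        # dk-dimensional slice of the model dimension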
# Modules required to build Encoder
self.src_embeddings = Embedding(src_vocab_size, dmodel)
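        # NOTE: max_seq_length is (somewhat arbitrarily) tied to the vocab size
        # below; presumably any bound >= the longest sequence in the data would do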
self.src_positional_encoding = PositionalEncoding(
dmodel,
max_seq_length = src_vocab_size,
pdropout = pdropout
)
self.encoder = Encoder(
dk,
dv,
h,
num_encoders,
dim_multiplier=dim_multiplier,
pdropout=pdropout)
# Modules required to build Decoder
self.target_embeddings = Embedding(target_vocab_size, dmodel)
self.target_positional_encoding = PositionalEncoding(
dmodel,
max_seq_length = target_vocab_size,
pdropout = pdropout
)
        self.decoder = Decoder(
            dk,
            dv,
            h,
            num_decoders,
            dim_multiplier=dim_multiplier,
            pdropout=pdropout)
# Final output
self.linear = nn.Linear(dmodel, target_vocab_size)
# self.softmax = nn.Softmax(dim=-1)
self.device = device
self.src_pad_idx = src_pad_idx
self.target_pad_idx = target_pad_idx
self.init_params()
# This part wasn't mentioned in the paper, but it's super important!
def init_params(self):
"""
        Xavier initialization has a tremendous impact! I didn't expect
        the model's performance, even with normalization layers, to be
        so dependent on the choice of weight initialization.
"""
for name, p in self.named_parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def make_src_mask(self, src):
"""
Args:
            src: padded token id sequences (batch_size, seq_length)
        Returns:
            src_mask: boolean mask per sequence (batch_size, 1, 1, seq_length);
                True where a token should be attended to, False at padding
                positions (self.src_pad_idx)
"""
batch_size = src.shape[0]
        # mark real tokens with True and padding tokens with False, then add
        # two singleton dims so the mask broadcasts over heads and query positions
src_mask = (src != self.src_pad_idx).view(batch_size, 1, 1, -1)
return src_mask
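    # Tiny illustration with hypothetical token ids and src_pad_idx = 2:
    #   src      = [[5, 8, 11, 2, 2]]                      (batch_size=1, seq_length=5)
    #   src_mask = [[[[True, True, True, False, False]]]]  (1, 1, 1, 5)
    # i.e. attention to the two trailing padding positions is masked out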
def make_target_mask(self, target):
"""
Args:
            target: padded token id sequences (batch_size, seq_length)
        Returns:
            target_mask: boolean mask combining padding and look-ahead masking
                (batch_size, 1, seq_length, seq_length)
"""
seq_length = target.shape[1]
batch_size = target.shape[0]
        # mark real tokens with True and padding tokens with False,
        # then add two singleton dims
target_mask = (target != self.target_pad_idx).view(batch_size, 1, 1, -1) # (batch_size, 1, 1, seq_length)
# generate subsequent mask
        trg_sub_mask = torch.tril(torch.ones((seq_length, seq_length), device=self.device)).bool()  # (seq_length, seq_length)
        # combine with bitwise "and": a position may be attended to only if it is
        # both a real (non-padding) token and not in the future
target_mask = target_mask & trg_sub_mask
return target_mask
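    # Tiny illustration with hypothetical token ids and target_pad_idx = 2:
    #   target = [[5, 8, 11, 2]]  ->  padding mask [True, True, True, False]
    #   combined with the lower-triangular look-ahead mask, target_mask[0, 0] is
    #     [[True, False, False, False],
    #      [True, True,  False, False],
    #      [True, True,  True,  False],
    #      [True, True,  True,  False]]
    # (the last column stays False because position 3 is padding)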
def forward(
self,
src_token_ids_batch,
target_token_ids_batch):
# create source and target masks
src_mask = self.make_src_mask(
src_token_ids_batch) # (batch_size, 1, 1, src_seq_length)
target_mask = self.make_target_mask(
target_token_ids_batch) # (batch_size, 1, trg_seq_length, trg_seq_length)
# Create embeddings
src_representations = self.src_embeddings(src_token_ids_batch)
src_representations = self.src_positional_encoding(src_representations)
target_representations = self.target_embeddings(target_token_ids_batch)
target_representations = self.target_positional_encoding(target_representations)
# Encode
encoded_src = self.encoder(src_representations, src_mask)
# Decode
decoded_output = self.decoder(
target_representations,
encoded_src,
target_mask,
src_mask)
# Post processing
out = self.linear(decoded_output)
        # No softmax here: the loss is computed on these raw logits
        # (e.g. nn.CrossEntropyLoss applies log-softmax internally)
# # Output
# out = self.softmax(out)
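        # Sketch of how these logits would typically be consumed during training
        # (not part of this module; `target_labels` is a hypothetical tensor of
        # gold next-token ids with shape (batch_size, trg_seq_length)):
        #   criterion = nn.CrossEntropyLoss(ignore_index=self.target_pad_idx)
        #   loss = criterion(out.view(-1, out.size(-1)), target_labels.view(-1))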
return out
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
if __name__ == "__main__":
"""
    The following hyperparameters are for the Multi30K dataset.
"""
dk = 32
dv = 32
h = 8
src_vocab_size = 7983
target_vocab_size = 5979
src_pad_idx = 2
target_pad_idx = 2
num_encoders = 3
num_decoders = 3
dim_multiplier = 4
pdropout=0.1
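    # Note: this is a scaled-down configuration (dmodel = dk*h = 256, 3 encoder /
    # 3 decoder layers) compared to the base model in "Attention Is All You Need"
    # (dmodel = 512, dk = dv = 64, 6 layers each)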
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Transformer(
        dk,
        dv,
        h,
        src_vocab_size,
        target_vocab_size,
        num_encoders,
        num_decoders,
        src_pad_idx,
        target_pad_idx,
        dim_multiplier,
        pdropout,
        device=device)
    model = model.to(device)
print(model)
print(f'The model has {count_parameters(model):,} trainable parameters')
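    # Quick smoke test with random token ids (a minimal sketch; assumes the
    # encoder/decoder sublayers accept LongTensor batches of shape
    # (batch_size, seq_length)):
    batch_size, src_len, trg_len = 2, 10, 12
    src_batch = torch.randint(0, src_vocab_size, (batch_size, src_len), device=device)
    trg_batch = torch.randint(0, target_vocab_size, (batch_size, trg_len), device=device)
    with torch.no_grad():
        logits = model(src_batch, trg_batch)
    # expect raw (un-softmaxed) logits of shape (batch_size, trg_len, target_vocab_size)
    print(f'Output logits shape: {tuple(logits.shape)}')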